|
| 1 | +import logging |
| 2 | +from typing import Any, List, Set, Type, TypeVar |
| 3 | + |
| 4 | +from pydantic import BaseModel |
| 5 | +from pydantic.schema import schema |
| 6 | + |
| 7 | +from . import Components, OpenAPI, Reference, Schema |
| 8 | + |
| 9 | +PydanticType = TypeVar("PydanticType", bound=BaseModel) |
| 10 | +ref_prefix = "#/components/schemas/" |
| 11 | + |
| 12 | + |
| 13 | +class PydanticSchema(Schema): |
| 14 | + """Special `Schema` class to indicate a reference from pydantic class""" |
| 15 | + |
| 16 | + schema_class: Type[PydanticType] = ... |
| 17 | + """the class that is used for generate the schema""" |
| 18 | + |
| 19 | + |
| 20 | +def construct_open_api_with_schema_class( |
| 21 | + open_api: OpenAPI, |
| 22 | + schema_classes: List[Type[PydanticType]] = None, |
| 23 | + scan_for_pydantic_schema_reference: bool = True, |
| 24 | + by_alias: bool = True, |
| 25 | +) -> OpenAPI: |
| 26 | + """ |
| 27 | + Construct a new OpenAPI object, with the use of pydantic classes to produce JSON schemas |
| 28 | +
|
| 29 | + :param open_api: the base `OpenAPI` object |
| 30 | + :param schema_classes: pydanitic classes that their schema will be used "#/components/schemas" values |
| 31 | + :param scan_for_pydantic_schema_reference: flag to indicate if scanning for `PydanticSchemaReference` class |
| 32 | + is needed for "#/components/schemas" value updates |
| 33 | + :param by_alias: construct schema by alias (default is True) |
| 34 | + :return: new OpenAPI object with "#/components/schemas" values updated. |
| 35 | + If there is no update in "#/components/schemas" values, the original `open_api` will be returned. |
| 36 | + """ |
| 37 | + new_open_api: OpenAPI = open_api.copy(deep=True) |
| 38 | + if scan_for_pydantic_schema_reference: |
| 39 | + extracted_schema_classes = _handle_pydantic_schema(new_open_api) |
| 40 | + if schema_classes: |
| 41 | + schema_classes = list({*schema_classes, *_handle_pydantic_schema(new_open_api)}) |
| 42 | + else: |
| 43 | + schema_classes = extracted_schema_classes |
| 44 | + |
| 45 | + if not schema_classes: |
| 46 | + return open_api |
| 47 | + |
| 48 | + schema_classes.sort(key=lambda x: x.__name__) |
| 49 | + logging.debug(f"schema_classes{schema_classes}") |
| 50 | + |
| 51 | + # update new_open_api with new #/components/schemas |
| 52 | + schema_definitions = schema(schema_classes, by_alias=by_alias, ref_prefix=ref_prefix) |
| 53 | + if not new_open_api.components: |
| 54 | + new_open_api.components = Components() |
| 55 | + if new_open_api.components.schemas: |
| 56 | + for existing_key in new_open_api.components.schemas: |
| 57 | + if existing_key in schema_definitions.get("definitions"): |
| 58 | + logging.warning( |
| 59 | + f'"{existing_key}" already exists in {ref_prefix}. ' |
| 60 | + f'The value of "{ref_prefix}{existing_key}" will be overwritten.' |
| 61 | + ) |
| 62 | + new_open_api.components.schemas.update( |
| 63 | + {key: Schema.parse_obj(schema_dict) for key, schema_dict in schema_definitions.get("definitions").items()} |
| 64 | + ) |
| 65 | + else: |
| 66 | + new_open_api.components.schemas = { |
| 67 | + key: Schema.parse_obj(schema_dict) for key, schema_dict in schema_definitions.get("definitions").items() |
| 68 | + } |
| 69 | + return new_open_api |
| 70 | + |
| 71 | + |
| 72 | +def _handle_pydantic_schema(open_api: OpenAPI) -> List[Type[PydanticType]]: |
| 73 | + """ |
| 74 | + This function traverses the `OpenAPI` object and |
| 75 | +
|
| 76 | + 1. Replaces the `PydanticSchema` object with `Reference` object, with correct ref value; |
| 77 | + 2. Extracts the involved schema class from `PydanticSchema` object. |
| 78 | +
|
| 79 | + **This function will mutate the input `OpenAPI` object.** |
| 80 | +
|
| 81 | + :param open_api: the `OpenAPI` object to be traversed and mutated |
| 82 | + :return: a list of schema classes extracted from `PydanticSchema` objects |
| 83 | + """ |
| 84 | + |
| 85 | + pydantic_types: Set[Type[PydanticType]] = set() |
| 86 | + |
| 87 | + def _traverse(obj: Any): |
| 88 | + if isinstance(obj, BaseModel): |
| 89 | + fields = obj.__fields_set__ |
| 90 | + for field in fields: |
| 91 | + child_obj = obj.__getattribute__(field) |
| 92 | + if isinstance(child_obj, PydanticSchema): |
| 93 | + logging.debug(f"PydanticSchema found in {obj.__repr_name__()}: {child_obj}") |
| 94 | + obj.__setattr__(field, _construct_ref_obj(child_obj)) |
| 95 | + pydantic_types.add(child_obj.schema_class) |
| 96 | + else: |
| 97 | + _traverse(child_obj) |
| 98 | + elif isinstance(obj, list): |
| 99 | + for index, elem in enumerate(obj): |
| 100 | + if isinstance(elem, PydanticSchema): |
| 101 | + logging.debug(f"PydanticSchema found in list: {elem}") |
| 102 | + obj[index] = _construct_ref_obj(elem) |
| 103 | + pydantic_types.add(elem.schema_class) |
| 104 | + else: |
| 105 | + _traverse(elem) |
| 106 | + elif isinstance(obj, dict): |
| 107 | + for key, value in obj.items(): |
| 108 | + if isinstance(value, PydanticSchema): |
| 109 | + logging.debug(f"PydanticSchema found in dict: {value}") |
| 110 | + obj[key] = _construct_ref_obj(value) |
| 111 | + pydantic_types.add(value.schema_class) |
| 112 | + else: |
| 113 | + _traverse(value) |
| 114 | + |
| 115 | + _traverse(open_api) |
| 116 | + return list(pydantic_types) |
| 117 | + |
| 118 | + |
| 119 | +def _construct_ref_obj(pydantic_schema: PydanticSchema): |
| 120 | + ref_obj = Reference(ref=ref_prefix + pydantic_schema.schema_class.__name__) |
| 121 | + logging.debug(f"ref_obj={ref_obj}") |
| 122 | + return ref_obj |
0 commit comments