1
1
import logging
2
- from typing import Any , Generic , List , Optional , Set , Type , TypeVar
2
+ from typing import Any , Generic , List , Optional , Set , Type , TypeVar , cast
3
3
4
4
from pydantic import BaseModel
5
5
11
11
v1_schema ,
12
12
)
13
13
14
- from . import Components , OpenAPI , Reference , Schema
14
+ from . import Components , OpenAPI , Reference , Schema , schema_validate
15
15
16
16
logger = logging .getLogger (__name__ )
17
17
@@ -32,13 +32,16 @@ def get_mode(
32
32
) -> JsonSchemaMode :
33
33
"""Get the JSON schema mode for a model class.
34
34
35
- The mode can be either "serialization " or "validation ". In validation mode,
35
+ The mode can be either "validation " or "serialization ". In validation mode,
36
36
computed fields are dropped and optional fields remain optional. In
37
37
serialization mode, computed and optional fields are required.
38
38
"""
39
39
if not hasattr (cls , "model_config" ):
40
40
return default
41
- return cls .model_config .get ("json_schema_mode" , default )
41
+ mode = cls .model_config .get ("json_schema_mode" , default )
42
+ if mode not in ("validation" , "serialization" ):
43
+ raise ValueError (f"invalid json_schema_mode: { mode } " )
44
+ return cast (JsonSchemaMode , mode )
42
45
43
46
44
47
def construct_open_api_with_schema_class (
@@ -62,10 +65,8 @@ def construct_open_api_with_schema_class(
62
65
If there is no update in "#/components/schemas" values, the original
63
66
`open_api` will be returned.
64
67
"""
65
- if PYDANTIC_V2 :
66
- new_open_api = open_api .model_copy (deep = True )
67
- else :
68
- new_open_api = open_api .copy (deep = True )
68
+ copy_func = getattr (open_api , "model_copy" if PYDANTIC_V2 else "copy" )
69
+ new_open_api : OpenAPI = copy_func (deep = True )
69
70
70
71
if scan_for_pydantic_schema_reference :
71
72
extracted_schema_classes = _handle_pydantic_schema (new_open_api )
@@ -80,7 +81,7 @@ def construct_open_api_with_schema_class(
80
81
return open_api
81
82
82
83
schema_classes .sort (key = lambda x : x .__name__ )
83
- logger .debug (f "schema_classes{ schema_classes } " )
84
+ logger .debug ("schema_classes: %s" , schema_classes )
84
85
85
86
# update new_open_api with new #/components/schemas
86
87
if PYDANTIC_V2 :
@@ -94,7 +95,6 @@ def construct_open_api_with_schema_class(
94
95
schema_classes , by_alias = by_alias , ref_prefix = ref_prefix
95
96
)
96
97
97
- schema_validate = Schema .model_validate if PYDANTIC_V2 else Schema .parse_obj
98
98
if not new_open_api .components :
99
99
new_open_api .components = Components ()
100
100
if new_open_api .components .schemas :
@@ -111,6 +111,8 @@ def construct_open_api_with_schema_class(
111
111
}
112
112
)
113
113
else :
114
+ for key , schema_dict in schema_definitions [DEFS_KEY ].items ():
115
+ schema_validate (schema_dict )
114
116
new_open_api .components .schemas = {
115
117
key : schema_validate (schema_dict )
116
118
for key , schema_dict in schema_definitions [DEFS_KEY ].items ()
@@ -136,13 +138,13 @@ def _handle_pydantic_schema(open_api: OpenAPI) -> List[Type[BaseModel]]:
136
138
137
139
def _traverse (obj : Any ) -> None :
138
140
if isinstance (obj , BaseModel ):
139
- fields = obj .model_fields_set if PYDANTIC_V2 else obj .__fields_set__
141
+ fields = getattr (
142
+ obj , "model_fields_set" if PYDANTIC_V2 else "__fields_set__"
143
+ )
140
144
for field in fields :
141
145
child_obj = obj .__getattribute__ (field )
142
146
if isinstance (child_obj , PydanticSchema ):
143
- logger .debug (
144
- f"PydanticSchema found in { obj .__repr_name__ ()} : { child_obj } "
145
- )
147
+ logger .debug ("PydanticSchema found in %s: %s" , obj , child_obj )
146
148
obj .__setattr__ (field , _construct_ref_obj (child_obj ))
147
149
pydantic_types .add (child_obj .schema_class )
148
150
else :
@@ -169,6 +171,6 @@ def _traverse(obj: Any) -> None:
169
171
170
172
171
173
def _construct_ref_obj (pydantic_schema : PydanticSchema [PydanticType ]) -> Reference :
172
- ref_obj = Reference (ref = ref_prefix + pydantic_schema .schema_class .__name__ )
174
+ ref_obj = Reference (** { "$ ref" : ref_prefix + pydantic_schema .schema_class .__name__ } )
173
175
logger .debug (f"ref_obj={ ref_obj } " )
174
176
return ref_obj
0 commit comments