2525 LiteralType ,
2626 OpenAPIScope ,
2727 PythonVersion ,
28- cached_property ,
2928 snooper_to_methods ,
3029)
3130from datamodel_code_generator .imports import Import , Imports
3231from datamodel_code_generator .model import DataModel , DataModelFieldBase
3332from datamodel_code_generator .model import pydantic as pydantic_model
34- from datamodel_code_generator .model .pydantic import DataModelField
33+ from datamodel_code_generator .model .pydantic import CustomRootType , DataModelField
3534from datamodel_code_generator .parser .jsonschema import JsonSchemaObject
3635from datamodel_code_generator .parser .openapi import MediaObject
3736from datamodel_code_generator .parser .openapi import OpenAPIParser as OpenAPIModelParser
4342 ResponseObject ,
4443)
4544from datamodel_code_generator .types import DataType , DataTypeManager , StrictTypes
46- from pydantic import BaseModel
45+ from datamodel_code_generator .util import cached_property
46+ from pydantic import BaseModel , ValidationInfo
4747
4848RE_APPLICATION_JSON_PATTERN : Pattern [str ] = re .compile (r'^application/.*json$' )
4949
@@ -72,7 +72,7 @@ def __get_validators__(cls) -> Any:
7272 yield cls .validate
7373
7474 @classmethod
75- def validate (cls , v : Any ) -> Any :
75+ def validate (cls , v : Any , info : ValidationInfo ) -> Any :
7676 return cls (v )
7777
7878 @property
@@ -91,8 +91,8 @@ def camelcase(self) -> str:
9191class Argument (CachedPropertyModel ):
9292 name : UsefulStr
9393 type_hint : UsefulStr
94- default : Optional [UsefulStr ]
95- default_value : Optional [UsefulStr ]
94+ default : Optional [UsefulStr ] = None
95+ default_value : Optional [UsefulStr ] = None
9696 required : bool
9797
9898 def __str__ (self ) -> str :
@@ -108,20 +108,20 @@ def argument(self) -> str:
108108class Operation (CachedPropertyModel ):
109109 method : UsefulStr
110110 path : UsefulStr
111- operationId : Optional [UsefulStr ]
112- description : Optional [str ]
113- summary : Optional [str ]
111+ operationId : Optional [UsefulStr ] = None
112+ description : Optional [str ] = None
113+ summary : Optional [str ] = None
114114 parameters : List [Dict [str , Any ]] = []
115115 responses : Dict [UsefulStr , Any ] = {}
116116 deprecated : bool = False
117117 imports : List [Import ] = []
118118 security : Optional [List [Dict [str , List [str ]]]] = None
119- tags : Optional [List [str ]]
119+ tags : Optional [List [str ]] = []
120120 arguments : str = ''
121121 snake_case_arguments : str = ''
122122 request : Optional [Argument ] = None
123123 response : str = ''
124- additional_responses : Dict [str , Dict [str , str ]] = {}
124+ additional_responses : Dict [Union [ str , int ] , Dict [str , str ]] = {}
125125 return_type : str = ''
126126
127127 @cached_property
@@ -245,16 +245,22 @@ def parse_info(self) -> Optional[Dict[str, Any]]:
245245 result ['servers' ] = servers
246246 return result or None
247247
248- def parse_parameters (self , parameters : ParameterObject , path : List [str ]) -> None :
249- super ().parse_parameters (parameters , path )
250- self ._temporary_operation ['_parameters' ].append (parameters )
248+ def parse_all_parameters (
249+ self ,
250+ name : str ,
251+ parameters : List [Union [ReferenceObject , ParameterObject ]],
252+ path : List [str ],
253+ ) -> None :
254+ super ().parse_all_parameters (name , parameters , path )
255+ self ._temporary_operation ['_parameters' ].extend (parameters )
251256
252257 def get_parameter_type (
253258 self ,
254- parameters : ParameterObject ,
259+ parameters : Union [ ReferenceObject , ParameterObject ] ,
255260 snake_case : bool ,
256261 path : List [str ],
257262 ) -> Optional [Argument ]:
263+ parameters = self .resolve_object (parameters , ParameterObject )
258264 orig_name = parameters .name
259265 if snake_case :
260266 name = stringcase .snakecase (parameters .name )
@@ -274,7 +280,10 @@ def get_parameter_type(
274280 if not data_type :
275281 if not schema :
276282 schema = parameters .schema_
283+ if schema is None :
284+ raise RuntimeError ("schema is None" ) # pragma: no cover
277285 data_type = self .parse_schema (name , schema , [* path , name ])
286+ data_type = self ._collapse_root_model (data_type )
278287 if not schema :
279288 return None
280289
@@ -297,9 +306,11 @@ def get_parameter_type(
297306 default = repr (schema .default ) if schema .has_default else None
298307 self .imports_for_fastapi .append (field .imports )
299308 self .data_types .append (field .data_type )
309+ if field .name is None :
310+ raise RuntimeError ("field.name is None" ) # pragma: no cover
300311 return Argument (
301- name = field .name ,
302- type_hint = field .type_hint ,
312+ name = UsefulStr ( field .name ) ,
313+ type_hint = UsefulStr ( field .type_hint ) ,
303314 default = default , # type: ignore
304315 default_value = schema .default ,
305316 required = field .required ,
@@ -361,11 +372,12 @@ def parse_request_body(
361372 data_type = self .parse_schema (
362373 name , media_obj .schema_ , [* path , media_type ]
363374 )
375+ data_type = self ._collapse_root_model (data_type )
364376 arguments .append (
365377 # TODO: support multiple body
366378 Argument (
367379 name = 'body' , # type: ignore
368- type_hint = data_type .type_hint ,
380+ type_hint = UsefulStr ( data_type .type_hint ) ,
369381 required = request_body .required ,
370382 )
371383 )
@@ -406,17 +418,18 @@ def parse_request_body(
406418 )
407419 self ._temporary_operation ['_request' ] = arguments [0 ] if arguments else None
408420
409- def parse_responses (
421+ def parse_responses ( # type: ignore[override]
410422 self ,
411423 name : str ,
412424 responses : Dict [str , Union [ResponseObject , ReferenceObject ]],
413425 path : List [str ],
414- ) -> Dict [str , Dict [str , DataType ]]:
415- data_types = super ().parse_responses (name , responses , path )
426+ ) -> Dict [Union [ str , int ] , Dict [str , DataType ]]:
427+ data_types = super ().parse_responses (name , responses , path ) # type: ignore[arg-type]
416428 status_code_200 = data_types .get ('200' )
417429 if status_code_200 :
418430 data_type = list (status_code_200 .values ())[0 ]
419431 if data_type :
432+ data_type = self ._collapse_root_model (data_type )
420433 self .data_types .append (data_type )
421434 else :
422435 data_type = DataType (type = 'None' )
@@ -466,3 +479,24 @@ def parse_operation(
466479 path = f'/{ path_name } ' , # type: ignore
467480 method = method , # type: ignore
468481 )
482+
483+ def _collapse_root_model (self , data_type : DataType ) -> DataType :
484+ reference = data_type .reference
485+ import functools
486+
487+ if not (
488+ reference
489+ and (
490+ len (reference .children ) == 1
491+ or functools .reduce (lambda a , b : a == b , reference .children )
492+ )
493+ ):
494+ return data_type
495+ source = reference .source
496+ if not isinstance (source , CustomRootType ):
497+ return data_type
498+ data_type .remove_reference ()
499+ data_type = source .fields [0 ].data_type
500+ if source in self .results :
501+ self .results .remove (source )
502+ return data_type
0 commit comments