22
33import pathlib
44import re
5+ from functools import cached_property
56from typing import (
67 Any ,
78 Callable ,
4243 ResponseObject ,
4344)
4445from datamodel_code_generator .types import DataType , DataTypeManager , StrictTypes
45- from datamodel_code_generator .util import cached_property
4646from pydantic import BaseModel , ValidationInfo
4747
4848RE_APPLICATION_JSON_PATTERN : Pattern [str ] = re .compile (r'^application/.*json$' )
@@ -93,16 +93,43 @@ class Argument(CachedPropertyModel):
9393 type_hint : UsefulStr
9494 default : Optional [UsefulStr ] = None
9595 default_value : Optional [UsefulStr ] = None
96+ field : Union [DataModelField , list [DataModelField ], None ] = None
9697 required : bool
9798
9899 def __str__ (self ) -> str :
99100 return self .argument
100101
101- @cached_property
102+ @property
102103 def argument (self ) -> str :
104+ if self .field is None :
105+ type_hint = self .type_hint
106+ else :
107+ type_hint = (
108+ UsefulStr (self .field .type_hint )
109+ if not isinstance (self .field , list )
110+ else UsefulStr (
111+ f"Union[{ ', ' .join (field .type_hint for field in self .field )} ]"
112+ )
113+ )
114+ if self .default is None and self .required :
115+ return f'{ self .name } : { type_hint } '
116+ return f'{ self .name } : { type_hint } = { self .default } '
117+
118+ @property
119+ def snakecase (self ) -> str :
120+ if self .field is None :
121+ type_hint = self .type_hint
122+ else :
123+ type_hint = (
124+ UsefulStr (self .field .type_hint )
125+ if not isinstance (self .field , list )
126+ else UsefulStr (
127+ f"Union[{ ', ' .join (field .type_hint for field in self .field )} ]"
128+ )
129+ )
103130 if self .default is None and self .required :
104- return f'{ self .name } : { self . type_hint } '
105- return f'{ self .name } : { self . type_hint } = { self .default } '
131+ return f'{ stringcase . snakecase ( self .name ) } : { type_hint } '
132+ return f'{ stringcase . snakecase ( self .name ) } : { type_hint } = { self .default } '
106133
107134
108135class Operation (CachedPropertyModel ):
@@ -114,16 +141,39 @@ class Operation(CachedPropertyModel):
114141 parameters : List [Dict [str , Any ]] = []
115142 responses : Dict [UsefulStr , Any ] = {}
116143 deprecated : bool = False
117- imports : List [Import ] = []
118144 security : Optional [List [Dict [str , List [str ]]]] = None
119145 tags : Optional [List [str ]] = []
120- arguments : str = ''
121- snake_case_arguments : str = ''
122146 request : Optional [Argument ] = None
123147 response : str = ''
124148 additional_responses : Dict [Union [str , int ], Dict [str , str ]] = {}
125149 return_type : str = ''
126150 callbacks : Dict [UsefulStr , List ["Operation" ]] = {}
151+ arguments_list : List [Argument ] = []
152+
153+ @classmethod
154+ def merge_arguments_with_union (cls , arguments : List [Argument ]) -> List [Argument ]:
155+ grouped_arguments : DefaultDict [str , List [Argument ]] = DefaultDict (list )
156+ for argument in arguments :
157+ grouped_arguments [argument .name ].append (argument )
158+
159+ merged_arguments = []
160+ for argument_list in grouped_arguments .values ():
161+ if len (argument_list ) == 1 :
162+ merged_arguments .append (argument_list [0 ])
163+ else :
164+ argument = argument_list [0 ]
165+ fields = [
166+ item
167+ for arg in argument_list
168+ if arg .field is not None
169+ for item in (
170+ arg .field if isinstance (arg .field , list ) else [arg .field ]
171+ )
172+ if item is not None
173+ ]
174+ argument .field = fields
175+ merged_arguments .append (argument )
176+ return merged_arguments
127177
128178 @cached_property
129179 def type (self ) -> UsefulStr :
@@ -132,6 +182,27 @@ def type(self) -> UsefulStr:
132182 """
133183 return self .method
134184
185+ @property
186+ def arguments (self ) -> str :
187+ sorted_arguments = Operation .merge_arguments_with_union (self .arguments_list )
188+ return ", " .join (argument .argument for argument in sorted_arguments )
189+
190+ @property
191+ def snake_case_arguments (self ) -> str :
192+ sorted_arguments = Operation .merge_arguments_with_union (self .arguments_list )
193+ return ", " .join (argument .snakecase for argument in sorted_arguments )
194+
195+ @property
196+ def imports (self ) -> Imports :
197+ imports = Imports ()
198+ for argument in self .arguments_list :
199+ if isinstance (argument .field , list ):
200+ for field in argument .field :
201+ imports .append (field .data_type .import_ )
202+ elif argument .field :
203+ imports .append (argument .field .data_type .import_ )
204+ return imports
205+
135206 @cached_property
136207 def root_path (self ) -> UsefulStr :
137208 paths = self .path .split ("/" )
@@ -153,7 +224,7 @@ def function_name(self) -> str:
153224 return stringcase .snakecase (name )
154225
155226
156- @snooper_to_methods (max_variable_length = None )
227+ @snooper_to_methods ()
157228class OpenAPIParser (OpenAPIModelParser ):
158229 def __init__ (
159230 self ,
@@ -166,7 +237,7 @@ def __init__(
166237 base_class : Optional [str ] = None ,
167238 custom_template_dir : Optional [pathlib .Path ] = None ,
168239 extra_template_data : Optional [DefaultDict [str , Dict [str , Any ]]] = None ,
169- target_python_version : PythonVersion = PythonVersion .PY_37 ,
240+ target_python_version : PythonVersion = PythonVersion .PY_39 ,
170241 dump_resolve_reference_action : Optional [Callable [[Iterable [str ]], str ]] = None ,
171242 validation : bool = False ,
172243 field_constraints : bool = False ,
@@ -314,6 +385,7 @@ def get_parameter_type(
314385 default = default , # type: ignore
315386 default_value = schema .default ,
316387 required = field .required ,
388+ field = field ,
317389 )
318390
319391 def get_arguments (self , snake_case : bool , path : List [str ]) -> str :
@@ -347,6 +419,10 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
347419 or argument .type_hint .startswith ('Optional[' )
348420 )
349421
422+ # check if there are duplicate argument.name
423+ argument_names = [argument .name for argument in arguments ]
424+ if len (argument_names ) != len (set (argument_names )):
425+ self .imports_for_fastapi .append (Import (from_ = 'typing' , import_ = "Union" ))
350426 return arguments
351427
352428 def parse_request_body (
@@ -466,10 +542,7 @@ def parse_operation(
466542 resolved_path = self .model_resolver .resolve_ref (path )
467543 path_name , method = path [- 2 :]
468544
469- self ._temporary_operation ['arguments' ] = self .get_arguments (
470- snake_case = False , path = path
471- )
472- self ._temporary_operation ['snake_case_arguments' ] = self .get_arguments (
545+ self ._temporary_operation ['arguments_list' ] = self .get_argument_list (
473546 snake_case = True , path = path
474547 )
475548 main_operation = self ._temporary_operation
@@ -499,11 +572,8 @@ def parse_operation(
499572 self ._temporary_operation = {'_parameters' : []}
500573 cb_path = path + ['callbacks' , key , route , method ]
501574 super ().parse_operation (cb_op , cb_path )
502- self ._temporary_operation ['arguments' ] = self .get_arguments (
503- snake_case = False , path = cb_path
504- )
505- self ._temporary_operation ['snake_case_arguments' ] = (
506- self .get_arguments (snake_case = True , path = cb_path )
575+ self ._temporary_operation ['arguments_list' ] = (
576+ self .get_argument_list (snake_case = True , path = cb_path )
507577 )
508578
509579 callbacks [key ].append (
@@ -527,13 +597,16 @@ def _collapse_root_model(self, data_type: DataType) -> DataType:
527597 reference = data_type .reference
528598 import functools
529599
530- if not (
531- reference
532- and (
533- len (reference .children ) == 1
534- or functools .reduce (lambda a , b : a == b , reference .children )
535- )
536- ):
600+ try :
601+ if not (
602+ reference
603+ and (
604+ len (reference .children ) == 0
605+ or functools .reduce (lambda a , b : a == b , reference .children )
606+ )
607+ ):
608+ return data_type
609+ except RecursionError :
537610 return data_type
538611 source = reference .source
539612 if not isinstance (source , CustomRootType ):
0 commit comments