@@ -93,17 +93,44 @@ class Argument(CachedPropertyModel):
9393 type_hint : UsefulStr
9494 default : Optional [UsefulStr ] = None
9595 default_value : Optional [UsefulStr ] = None
96+ field : Union [DataModelField , list [DataModelField ], None ] = None
9697 required : bool
9798
9899 def __str__ (self ) -> str :
99100 return self .argument
100101
101102 @cached_property
102103 def argument (self ) -> str :
104+ if self .field is None :
105+ type_hint = self .type_hint
106+ else :
107+ type_hint = (
108+ UsefulStr (self .field .type_hint )
109+ if not isinstance (self .field , list )
110+ else UsefulStr (
111+ f"Union[{ ', ' .join (field .type_hint for field in self .field )} ]"
112+ )
113+ )
103114 if self .default is None and self .required :
104- return f'{ self .name } : { self . type_hint } '
115+ return f'{ self .name } : { type_hint } '
105116 return f'{ self .name } : { self .type_hint } = { self .default } '
106117
118+ @cached_property
119+ def snakecase (self ) -> str :
120+ if self .field is None :
121+ type_hint = self .type_hint
122+ else :
123+ type_hint = (
124+ UsefulStr (self .field .type_hint )
125+ if not isinstance (self .field , list )
126+ else UsefulStr (
127+ f"Union[{ ', ' .join (field .type_hint for field in self .field )} ]"
128+ )
129+ )
130+ if self .default is None and self .required :
131+ return f'{ stringcase .snakecase (self .name )} : { type_hint } '
132+ return f'{ stringcase .snakecase (self .name )} : { type_hint } = { self .default } '
133+
107134
108135class Operation (CachedPropertyModel ):
109136 method : UsefulStr
@@ -117,13 +144,39 @@ class Operation(CachedPropertyModel):
117144 imports : List [Import ] = []
118145 security : Optional [List [Dict [str , List [str ]]]] = None
119146 tags : Optional [List [str ]] = []
120- arguments : str = ''
121- snake_case_arguments : str = ''
122147 request : Optional [Argument ] = None
123148 response : str = ''
124149 additional_responses : Dict [Union [str , int ], Dict [str , str ]] = {}
125150 return_type : str = ''
126151 callbacks : Dict [UsefulStr , List ["Operation" ]] = {}
152+ arguments_list : List [Argument ] = []
153+
154+ @classmethod
155+ def merge_arguments_with_union (
156+ cls , arguments : List [Argument ], imports : List [Import ]
157+ ) -> List [Argument ]:
158+ grouped_arguments : DefaultDict [str , List [Argument ]] = DefaultDict (list )
159+ for argument in arguments :
160+ grouped_arguments [argument .name ].append (argument )
161+
162+ merged_arguments = []
163+ for argument_list in grouped_arguments .values ():
164+ if len (argument_list ) == 1 :
165+ merged_arguments .append (argument_list [0 ])
166+ else :
167+ argument = argument_list [0 ]
168+ fields = [
169+ item
170+ for arg in argument_list
171+ if arg .field is not None
172+ for item in (
173+ arg .field if isinstance (arg .field , list ) else [arg .field ]
174+ )
175+ if item is not None
176+ ]
177+ argument .field = fields
178+ merged_arguments .append (argument )
179+ return merged_arguments
127180
128181 @cached_property
129182 def type (self ) -> UsefulStr :
@@ -132,6 +185,20 @@ def type(self) -> UsefulStr:
132185 """
133186 return self .method
134187
188+ @property
189+ def arguments (self ) -> str :
190+ sorted_arguments = Operation .merge_arguments_with_union (
191+ self .arguments_list , self .imports
192+ )
193+ return ", " .join (argument .argument for argument in sorted_arguments )
194+
195+ @property
196+ def snake_case_arguments (self ) -> str :
197+ sorted_arguments = Operation .merge_arguments_with_union (
198+ self .arguments_list , self .imports
199+ )
200+ return ", " .join (argument .snakecase for argument in sorted_arguments )
201+
135202 @cached_property
136203 def root_path (self ) -> UsefulStr :
137204 paths = self .path .split ("/" )
@@ -314,6 +381,7 @@ def get_parameter_type(
314381 default = default , # type: ignore
315382 default_value = schema .default ,
316383 required = field .required ,
384+ field = field ,
317385 )
318386
319387 def get_arguments (self , snake_case : bool , path : List [str ]) -> str :
@@ -347,22 +415,11 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
347415 or argument .type_hint .startswith ('Optional[' )
348416 )
349417
350- # Group argument with same name into one argument.name argument.type_hint into Argument(name = argument.name, type_hint = "Union[argument.type_hint, argument.type_hint]")
351- grouped_arguments : DefaultDict [str , List [Argument ]] = DefaultDict (list )
352- sorted_arguments = []
353- for argument in arguments :
354- grouped_arguments [argument .name ].append (argument )
355- for argument_list in grouped_arguments .values ():
356- if len (argument_list ) == 1 :
357- sorted_arguments .append (argument_list [0 ])
358- else :
359- argument = argument_list [0 ]
360- type_hints = [arg .type_hint for arg in argument_list ]
361- argument .type_hint = UsefulStr (f"Union[{ ', ' .join (type_hints )} ]" )
362- self .imports_for_fastapi .append (Import (from_ = 'typing' , import_ = "Union" ))
363- sorted_arguments .append (argument )
364-
365- return sorted_arguments
418+ # check if there are duplicate argument.name
419+ argument_names = [argument .name for argument in arguments ]
420+ if len (argument_names ) != len (set (argument_names )):
421+ self .imports_for_fastapi .append (Import (from_ = 'typing' , import_ = "Union" ))
422+ return arguments
366423
367424 def parse_request_body (
368425 self ,
@@ -481,10 +538,7 @@ def parse_operation(
481538 resolved_path = self .model_resolver .resolve_ref (path )
482539 path_name , method = path [- 2 :]
483540
484- self ._temporary_operation ['arguments' ] = self .get_arguments (
485- snake_case = False , path = path
486- )
487- self ._temporary_operation ['snake_case_arguments' ] = self .get_arguments (
541+ self ._temporary_operation ['arguments_list' ] = self .get_argument_list (
488542 snake_case = True , path = path
489543 )
490544 main_operation = self ._temporary_operation
@@ -514,11 +568,8 @@ def parse_operation(
514568 self ._temporary_operation = {'_parameters' : []}
515569 cb_path = path + ['callbacks' , key , route , method ]
516570 super ().parse_operation (cb_op , cb_path )
517- self ._temporary_operation ['arguments' ] = self .get_arguments (
518- snake_case = False , path = cb_path
519- )
520- self ._temporary_operation ['snake_case_arguments' ] = (
521- self .get_arguments (snake_case = True , path = cb_path )
571+ self ._temporary_operation ['arguments_list' ] = (
572+ self .get_argument_list (snake_case = True , path = cb_path )
522573 )
523574
524575 callbacks [key ].append (
0 commit comments