Skip to content

Commit 86bd218

Browse files
authored
Merge pull request #2 from ag2ai/Bugfixing
Bugfixing
2 parents d8e3b66 + f1cc623 commit 86bd218

File tree

16 files changed

+603
-31
lines changed

16 files changed

+603
-31
lines changed

fastapi_code_generator/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .patches import patch_parse
2+
3+
patch_parse()

fastapi_code_generator/parser.py

Lines changed: 95 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,43 @@ class Argument(CachedPropertyModel):
9393
type_hint: UsefulStr
9494
default: Optional[UsefulStr] = None
9595
default_value: Optional[UsefulStr] = None
96+
field: Union[DataModelField, list[DataModelField], None] = None
9697
required: bool
9798

9899
def __str__(self) -> str:
99100
return self.argument
100101

101-
@cached_property
102+
@property
102103
def argument(self) -> str:
104+
if self.field is None:
105+
type_hint = self.type_hint
106+
else:
107+
type_hint = (
108+
UsefulStr(self.field.type_hint)
109+
if not isinstance(self.field, list)
110+
else UsefulStr(
111+
f"Union[{', '.join(field.type_hint for field in self.field)}]"
112+
)
113+
)
114+
if self.default is None and self.required:
115+
return f'{self.name}: {type_hint}'
116+
return f'{self.name}: {type_hint} = {self.default}'
117+
118+
@property
119+
def snakecase(self) -> str:
120+
if self.field is None:
121+
type_hint = self.type_hint
122+
else:
123+
type_hint = (
124+
UsefulStr(self.field.type_hint)
125+
if not isinstance(self.field, list)
126+
else UsefulStr(
127+
f"Union[{', '.join(field.type_hint for field in self.field)}]"
128+
)
129+
)
103130
if self.default is None and self.required:
104-
return f'{self.name}: {self.type_hint}'
105-
return f'{self.name}: {self.type_hint} = {self.default}'
131+
return f'{stringcase.snakecase(self.name)}: {type_hint}'
132+
return f'{stringcase.snakecase(self.name)}: {type_hint} = {self.default}'
106133

107134

108135
class Operation(CachedPropertyModel):
@@ -114,16 +141,39 @@ class Operation(CachedPropertyModel):
114141
parameters: List[Dict[str, Any]] = []
115142
responses: Dict[UsefulStr, Any] = {}
116143
deprecated: bool = False
117-
imports: List[Import] = []
118144
security: Optional[List[Dict[str, List[str]]]] = None
119145
tags: Optional[List[str]] = []
120-
arguments: str = ''
121-
snake_case_arguments: str = ''
122146
request: Optional[Argument] = None
123147
response: str = ''
124148
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
125149
return_type: str = ''
126150
callbacks: Dict[UsefulStr, List["Operation"]] = {}
151+
arguments_list: List[Argument] = []
152+
153+
@classmethod
154+
def merge_arguments_with_union(cls, arguments: List[Argument]) -> List[Argument]:
155+
grouped_arguments: DefaultDict[str, List[Argument]] = DefaultDict(list)
156+
for argument in arguments:
157+
grouped_arguments[argument.name].append(argument)
158+
159+
merged_arguments = []
160+
for argument_list in grouped_arguments.values():
161+
if len(argument_list) == 1:
162+
merged_arguments.append(argument_list[0])
163+
else:
164+
argument = argument_list[0]
165+
fields = [
166+
item
167+
for arg in argument_list
168+
if arg.field is not None
169+
for item in (
170+
arg.field if isinstance(arg.field, list) else [arg.field]
171+
)
172+
if item is not None
173+
]
174+
argument.field = fields
175+
merged_arguments.append(argument)
176+
return merged_arguments
127177

128178
@cached_property
129179
def type(self) -> UsefulStr:
@@ -132,6 +182,27 @@ def type(self) -> UsefulStr:
132182
"""
133183
return self.method
134184

185+
@property
186+
def arguments(self) -> str:
187+
sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list)
188+
return ", ".join(argument.argument for argument in sorted_arguments)
189+
190+
@property
191+
def snake_case_arguments(self) -> str:
192+
sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list)
193+
return ", ".join(argument.snakecase for argument in sorted_arguments)
194+
195+
@property
196+
def imports(self) -> Imports:
197+
imports = Imports()
198+
for argument in self.arguments_list:
199+
if isinstance(argument.field, list):
200+
for field in argument.field:
201+
imports.append(field.data_type.import_)
202+
elif argument.field:
203+
imports.append(argument.field.data_type.import_)
204+
return imports
205+
135206
@cached_property
136207
def root_path(self) -> UsefulStr:
137208
paths = self.path.split("/")
@@ -314,6 +385,7 @@ def get_parameter_type(
314385
default=default, # type: ignore
315386
default_value=schema.default,
316387
required=field.required,
388+
field=field,
317389
)
318390

319391
def get_arguments(self, snake_case: bool, path: List[str]) -> str:
@@ -347,6 +419,10 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
347419
or argument.type_hint.startswith('Optional[')
348420
)
349421

422+
# check if there are duplicate argument.name
423+
argument_names = [argument.name for argument in arguments]
424+
if len(argument_names) != len(set(argument_names)):
425+
self.imports_for_fastapi.append(Import(from_='typing', import_="Union"))
350426
return arguments
351427

352428
def parse_request_body(
@@ -466,10 +542,7 @@ def parse_operation(
466542
resolved_path = self.model_resolver.resolve_ref(path)
467543
path_name, method = path[-2:]
468544

469-
self._temporary_operation['arguments'] = self.get_arguments(
470-
snake_case=False, path=path
471-
)
472-
self._temporary_operation['snake_case_arguments'] = self.get_arguments(
545+
self._temporary_operation['arguments_list'] = self.get_argument_list(
473546
snake_case=True, path=path
474547
)
475548
main_operation = self._temporary_operation
@@ -499,11 +572,8 @@ def parse_operation(
499572
self._temporary_operation = {'_parameters': []}
500573
cb_path = path + ['callbacks', key, route, method]
501574
super().parse_operation(cb_op, cb_path)
502-
self._temporary_operation['arguments'] = self.get_arguments(
503-
snake_case=False, path=cb_path
504-
)
505-
self._temporary_operation['snake_case_arguments'] = (
506-
self.get_arguments(snake_case=True, path=cb_path)
575+
self._temporary_operation['arguments_list'] = (
576+
self.get_argument_list(snake_case=True, path=cb_path)
507577
)
508578

509579
callbacks[key].append(
@@ -527,13 +597,16 @@ def _collapse_root_model(self, data_type: DataType) -> DataType:
527597
reference = data_type.reference
528598
import functools
529599

530-
if not (
531-
reference
532-
and (
533-
len(reference.children) == 1
534-
or functools.reduce(lambda a, b: a == b, reference.children)
535-
)
536-
):
600+
try:
601+
if not (
602+
reference
603+
and (
604+
len(reference.children) == 0
605+
or functools.reduce(lambda a, b: a == b, reference.children)
606+
)
607+
):
608+
return data_type
609+
except RecursionError:
537610
return data_type
538611
source = reference.source
539612
if not isinstance(source, CustomRootType):

0 commit comments

Comments
 (0)