Skip to content

Commit bba7578

Browse files
committed
Refactor Operation.arguments and Operation.snake_case_arguments
1 parent 0f47707 commit bba7578

File tree

6 files changed

+98
-65
lines changed

6 files changed

+98
-65
lines changed

fastapi_code_generator/parser.py

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,44 @@ class Argument(CachedPropertyModel):
9393
type_hint: UsefulStr
9494
default: Optional[UsefulStr] = None
9595
default_value: Optional[UsefulStr] = None
96+
field: Union[DataModelField, list[DataModelField], None] = None
9697
required: bool
9798

9899
def __str__(self) -> str:
99100
return self.argument
100101

101102
@cached_property
102103
def argument(self) -> str:
104+
if self.field is None:
105+
type_hint = self.type_hint
106+
else:
107+
type_hint = (
108+
UsefulStr(self.field.type_hint)
109+
if not isinstance(self.field, list)
110+
else UsefulStr(
111+
f"Union[{', '.join(field.type_hint for field in self.field)}]"
112+
)
113+
)
103114
if self.default is None and self.required:
104-
return f'{self.name}: {self.type_hint}'
115+
return f'{self.name}: {type_hint}'
105116
return f'{self.name}: {self.type_hint} = {self.default}'
106117

118+
@cached_property
119+
def snakecase(self) -> str:
120+
if self.field is None:
121+
type_hint = self.type_hint
122+
else:
123+
type_hint = (
124+
UsefulStr(self.field.type_hint)
125+
if not isinstance(self.field, list)
126+
else UsefulStr(
127+
f"Union[{', '.join(field.type_hint for field in self.field)}]"
128+
)
129+
)
130+
if self.default is None and self.required:
131+
return f'{stringcase.snakecase(self.name)}: {type_hint}'
132+
return f'{stringcase.snakecase(self.name)}: {type_hint} = {self.default}'
133+
107134

108135
class Operation(CachedPropertyModel):
109136
method: UsefulStr
@@ -117,13 +144,39 @@ class Operation(CachedPropertyModel):
117144
imports: List[Import] = []
118145
security: Optional[List[Dict[str, List[str]]]] = None
119146
tags: Optional[List[str]] = []
120-
arguments: str = ''
121-
snake_case_arguments: str = ''
122147
request: Optional[Argument] = None
123148
response: str = ''
124149
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
125150
return_type: str = ''
126151
callbacks: Dict[UsefulStr, List["Operation"]] = {}
152+
arguments_list: List[Argument] = []
153+
154+
@classmethod
155+
def merge_arguments_with_union(
156+
cls, arguments: List[Argument], imports: List[Import]
157+
) -> List[Argument]:
158+
grouped_arguments: DefaultDict[str, List[Argument]] = DefaultDict(list)
159+
for argument in arguments:
160+
grouped_arguments[argument.name].append(argument)
161+
162+
merged_arguments = []
163+
for argument_list in grouped_arguments.values():
164+
if len(argument_list) == 1:
165+
merged_arguments.append(argument_list[0])
166+
else:
167+
argument = argument_list[0]
168+
fields = [
169+
item
170+
for arg in argument_list
171+
if arg.field is not None
172+
for item in (
173+
arg.field if isinstance(arg.field, list) else [arg.field]
174+
)
175+
if item is not None
176+
]
177+
argument.field = fields
178+
merged_arguments.append(argument)
179+
return merged_arguments
127180

128181
@cached_property
129182
def type(self) -> UsefulStr:
@@ -132,6 +185,20 @@ def type(self) -> UsefulStr:
132185
"""
133186
return self.method
134187

188+
@property
189+
def arguments(self) -> str:
190+
sorted_arguments = Operation.merge_arguments_with_union(
191+
self.arguments_list, self.imports
192+
)
193+
return ", ".join(argument.argument for argument in sorted_arguments)
194+
195+
@property
196+
def snake_case_arguments(self) -> str:
197+
sorted_arguments = Operation.merge_arguments_with_union(
198+
self.arguments_list, self.imports
199+
)
200+
return ", ".join(argument.snakecase for argument in sorted_arguments)
201+
135202
@cached_property
136203
def root_path(self) -> UsefulStr:
137204
paths = self.path.split("/")
@@ -314,6 +381,7 @@ def get_parameter_type(
314381
default=default, # type: ignore
315382
default_value=schema.default,
316383
required=field.required,
384+
field=field,
317385
)
318386

319387
def get_arguments(self, snake_case: bool, path: List[str]) -> str:
@@ -347,22 +415,11 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
347415
or argument.type_hint.startswith('Optional[')
348416
)
349417

350-
# Group argument with same name into one argument.name argument.type_hint into Argument(name = argument.name, type_hint = "Union[argument.type_hint, argument.type_hint]")
351-
grouped_arguments: DefaultDict[str, List[Argument]] = DefaultDict(list)
352-
sorted_arguments = []
353-
for argument in arguments:
354-
grouped_arguments[argument.name].append(argument)
355-
for argument_list in grouped_arguments.values():
356-
if len(argument_list) == 1:
357-
sorted_arguments.append(argument_list[0])
358-
else:
359-
argument = argument_list[0]
360-
type_hints = [arg.type_hint for arg in argument_list]
361-
argument.type_hint = UsefulStr(f"Union[{', '.join(type_hints)}]")
362-
self.imports_for_fastapi.append(Import(from_='typing', import_="Union"))
363-
sorted_arguments.append(argument)
364-
365-
return sorted_arguments
418+
# check if there are duplicate argument.name
419+
argument_names = [argument.name for argument in arguments]
420+
if len(argument_names) != len(set(argument_names)):
421+
self.imports_for_fastapi.append(Import(from_='typing', import_="Union"))
422+
return arguments
366423

367424
def parse_request_body(
368425
self,
@@ -481,10 +538,7 @@ def parse_operation(
481538
resolved_path = self.model_resolver.resolve_ref(path)
482539
path_name, method = path[-2:]
483540

484-
self._temporary_operation['arguments'] = self.get_arguments(
485-
snake_case=False, path=path
486-
)
487-
self._temporary_operation['snake_case_arguments'] = self.get_arguments(
541+
self._temporary_operation['arguments_list'] = self.get_argument_list(
488542
snake_case=True, path=path
489543
)
490544
main_operation = self._temporary_operation
@@ -514,11 +568,8 @@ def parse_operation(
514568
self._temporary_operation = {'_parameters': []}
515569
cb_path = path + ['callbacks', key, route, method]
516570
super().parse_operation(cb_op, cb_path)
517-
self._temporary_operation['arguments'] = self.get_arguments(
518-
snake_case=False, path=cb_path
519-
)
520-
self._temporary_operation['snake_case_arguments'] = (
521-
self.get_arguments(snake_case=True, path=cb_path)
571+
self._temporary_operation['arguments_list'] = (
572+
self.get_argument_list(snake_case=True, path=cb_path)
522573
)
523574

524575
callbacks[key].append(

here/main.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

here/models.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

tests/data/expected/openapi/default_template/duplicate_anonymus_parameter/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fastapi import FastAPI
1010

11-
from .models import HTTPValidationError
11+
from .models import HTTPValidationError, SpreadsheetId
1212

1313
app = FastAPI(
1414
title='test',
@@ -23,7 +23,7 @@
2323
'/get-sheet', response_model=str, responses={'422': {'model': HTTPValidationError}}
2424
)
2525
def get_sheet_get_sheet_get(
26-
spreadsheet_id: Optional[str] = None,
26+
spreadsheet_id: Optional[SpreadsheetId] = 'none',
2727
) -> Union[str, HTTPValidationError]:
2828
"""
2929
Get Sheet
@@ -37,7 +37,7 @@ def get_sheet_get_sheet_get(
3737
responses={'422': {'model': HTTPValidationError}},
3838
)
3939
def update_sheet_update_sheet_post(
40-
spreadsheet_id: Optional[str] = None,
40+
spreadsheet_id: Optional[SpreadsheetId] = 'none',
4141
) -> Union[str, HTTPValidationError]:
4242
"""
4343
Update Sheet

tests/data/expected/openapi/default_template/duplicate_anonymus_parameter/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
from enum import Enum
78
from typing import List, Optional, Union
89

910
from pydantic import BaseModel, Field
@@ -15,5 +16,10 @@ class ValidationError(BaseModel):
1516
type: str = Field(..., title='Error Type')
1617

1718

19+
class SpreadsheetId(Enum):
20+
cards = 'cards'
21+
none = 'none'
22+
23+
1824
class HTTPValidationError(BaseModel):
1925
detail: Optional[List[ValidationError]] = Field(None, title='Detail')

tests/data/openapi/default_template/duplicate_anonymus_parameter.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ paths:
1616
in: query
1717
required: false
1818
schema:
19-
anyOf:
20-
- type: string
21-
- type: 'null'
19+
type: string
20+
default: none
21+
enum:
22+
- cards
23+
- none
2224
title: Spreadsheet Id
2325
responses:
2426
'200':
@@ -44,9 +46,11 @@ paths:
4446
in: query
4547
required: false
4648
schema:
47-
anyOf:
48-
- type: string
49-
- type: 'null'
49+
type: string
50+
default: none
51+
enum:
52+
- cards
53+
- none
5054
title: Spreadsheet Id
5155
responses:
5256
'200':

0 commit comments

Comments
 (0)