Skip to content

Commit 0f47707

Browse files
committed
Code formatting
1 parent 763098c commit 0f47707

File tree

3 files changed

+112
-37
lines changed

3 files changed

+112
-37
lines changed

fastapi_code_generator/__init__.py

Lines changed: 84 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
import logging
5+
from itertools import groupby
56
from pathlib import Path
6-
from typing import Union
7+
from typing import NamedTuple
78

8-
logger = logging.getLogger(__name__)
9-
10-
from datamodel_code_generator.parser import base
9+
from datamodel_code_generator.format import CodeFormatter
10+
from datamodel_code_generator.imports import IMPORT_ANNOTATIONS, Import, Imports
1111
from datamodel_code_generator.model.base import DataModel
12+
from datamodel_code_generator.parser import base
13+
from datamodel_code_generator.reference import ModelResolver
14+
15+
logger = logging.getLogger(__name__)
1216

1317
# Save the original method before patching
1418
original_parse = base.Parser.parse
1519

20+
1621
def patch_parse() -> None: # noqa: C901
1722
def __alias_shadowed_imports(
18-
self,
23+
self: base.Parser,
1924
models: list[DataModel],
2025
all_model_field_names: set[str],
2126
) -> None:
@@ -24,21 +29,22 @@ def __alias_shadowed_imports(
2429
if model_field.data_type.type in all_model_field_names:
2530
alias = model_field.data_type.type + "_aliased"
2631
model_field.data_type.type = alias
27-
model_field.data_type.import_.alias = alias
32+
if model_field.data_type.import_:
33+
model_field.data_type.import_.alias = alias
2834

2935
def _parse( # noqa: PLR0912, PLR0914, PLR0915
30-
self,
36+
self: base.Parser,
3137
with_import: bool | None = True, # noqa: FBT001, FBT002
3238
format_: bool | None = True, # noqa: FBT001, FBT002
3339
settings_path: Path | None = None,
3440
) -> str | dict[tuple[str, ...], base.Result]:
3541
self.parse_raw()
3642

3743
if with_import:
38-
self.imports.append(base.IMPORT_ANNOTATIONS)
44+
self.imports.append(IMPORT_ANNOTATIONS)
3945

4046
if format_:
41-
code_formatter: base.CodeFormatter | None = base.CodeFormatter(
47+
code_formatter: CodeFormatter | None = CodeFormatter(
4248
self.target_python_version,
4349
settings_path,
4450
self.wrap_string_literal,
@@ -52,7 +58,9 @@ def _parse( # noqa: PLR0912, PLR0914, PLR0915
5258
else:
5359
code_formatter = None
5460

55-
_, sorted_data_models, require_update_action_models = base.sort_data_models(self.results)
61+
_, sorted_data_models, require_update_action_models = base.sort_data_models(
62+
self.results
63+
)
5664

5765
results: dict[tuple[str, ...], base.Result] = {}
5866

@@ -63,15 +71,17 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
6371
return (len(data_model.module_path), tuple(data_model.module_path))
6472

6573
# process in reverse order to correctly establish module levels
66-
grouped_models = base.groupby(
74+
grouped_models = groupby(
6775
sorted(sorted_data_models.values(), key=sort_key, reverse=True),
6876
key=module_key,
6977
)
7078

7179
module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
7280
unused_models: list[DataModel] = []
73-
model_to_module_models: dict[DataModel, tuple[tuple[str, ...], list[DataModel]]] = {}
74-
module_to_import: dict[tuple[str, ...], base.Imports] = {}
81+
model_to_module_models: dict[
82+
DataModel, tuple[tuple[str, ...], list[DataModel]]
83+
] = {}
84+
module_to_import: dict[tuple[str, ...], Imports] = {}
7585

7686
previous_module: tuple[str, ...] = ()
7787
for module, models in ((k, [*v]) for k, v in grouped_models):
@@ -87,23 +97,25 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
8797
)
8898
for parts in range(len(previous_module) - 1, len(module), -1)
8999
)
90-
module_models.append((
91-
module,
92-
models,
93-
))
100+
module_models.append(
101+
(
102+
module,
103+
models,
104+
)
105+
)
94106
previous_module = module
95107

96-
class Processed(base.NamedTuple):
108+
class Processed(NamedTuple):
97109
module: tuple[str, ...]
98110
models: list[DataModel]
99111
init: bool
100-
imports: base.Imports
101-
scoped_model_resolver: base.ModelResolver
112+
imports: Imports
113+
scoped_model_resolver: ModelResolver
102114

103115
processed_models: list[Processed] = []
104116

105117
for module_, models in module_models:
106-
imports = module_to_import[module_] = base.Imports(self.use_exact_imports)
118+
imports = module_to_import[module_] = Imports(self.use_exact_imports)
107119
init = False
108120
if module_:
109121
parent = (*module_[:-1], "__init__.py")
@@ -113,28 +125,42 @@ class Processed(base.NamedTuple):
113125
module = (*module_, "__init__.py")
114126
init = True
115127
else:
116-
module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py"))
128+
module = tuple(
129+
part.replace("-", "_")
130+
for part in (*module_[:-1], f"{module_[-1]}.py")
131+
)
117132
else:
118133
module = ("__init__.py",)
119134

120-
all_module_fields = {field.name for model in models for field in model.fields if field.name is not None}
121-
scoped_model_resolver = base.ModelResolver(exclude_names=all_module_fields)
135+
all_module_fields = {
136+
field.name
137+
for model in models
138+
for field in model.fields
139+
if field.name is not None
140+
}
141+
scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)
122142

123143
self.__alias_shadowed_imports(models, all_module_fields)
124144
self._Parser__override_required_field(models)
125145
self._Parser__replace_unique_list_to_set(models)
126-
self._Parser__change_from_import(models, imports, scoped_model_resolver, init)
146+
self._Parser__change_from_import(
147+
models, imports, scoped_model_resolver, init
148+
)
127149
self._Parser__extract_inherited_enum(models)
128150
self._Parser__set_reference_default_value_to_field(models)
129151
self._Parser__reuse_model(models, require_update_action_models)
130-
self._Parser__collapse_root_models(models, unused_models, imports, scoped_model_resolver)
152+
self._Parser__collapse_root_models(
153+
models, unused_models, imports, scoped_model_resolver
154+
)
131155
self._Parser__set_default_enum_member(models)
132156
self._Parser__sort_models(models, imports)
133157
self._Parser__change_field_name(models)
134158
self._Parser__apply_discriminator_type(models, imports)
135159
self._Parser__set_one_literal_on_default(models)
136160

137-
processed_models.append(Processed(module, models, init, imports, scoped_model_resolver))
161+
processed_models.append(
162+
Processed(module, models, init, imports, scoped_model_resolver)
163+
)
138164

139165
for processed_model in processed_models:
140166
for model in processed_model.models:
@@ -159,11 +185,25 @@ class Processed(base.NamedTuple):
159185
for from_, import_ in unused_imports:
160186
processed_model.imports.remove(Import(from_=from_, import_=import_))
161187

162-
for module, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007
188+
for (
189+
module,
190+
models,
191+
init,
192+
imports,
193+
scoped_model_resolver,
194+
) in processed_models: # noqa: B007
163195
# process after removing unused models
164-
self._Parser__change_imported_model_name(models, imports, scoped_model_resolver)
196+
self._Parser__change_imported_model_name(
197+
models, imports, scoped_model_resolver
198+
)
165199

166-
for module, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007
200+
for (
201+
module,
202+
models,
203+
init,
204+
imports,
205+
scoped_model_resolver,
206+
) in processed_models: # noqa: B007
167207
result: list[str] = []
168208
if models:
169209
if with_import:
@@ -176,7 +216,9 @@ class Processed(base.NamedTuple):
176216
result += [
177217
"\n",
178218
self.dump_resolve_reference_action(
179-
m.reference.short_name for m in models if m.path in require_update_action_models
219+
m.reference.short_name
220+
for m in models
221+
if m.path in require_update_action_models
180222
),
181223
]
182224
if not result and not init:
@@ -185,7 +227,9 @@ class Processed(base.NamedTuple):
185227
if code_formatter:
186228
body = code_formatter.format_code(body)
187229

188-
results[module] = base.Result(body=body, source=models[0].file_path if models else None)
230+
results[module] = base.Result(
231+
body=body, source=models[0].file_path if models else None
232+
)
189233

190234
# retain existing behaviour
191235
if [*results] == [("__init__.py",)]:
@@ -196,17 +240,20 @@ class Processed(base.NamedTuple):
196240
self._Parser__postprocess_result_modules(results)
197241
if self.treat_dot_as_module
198242
else {
199-
tuple((part[: part.rfind(".")].replace(".", "_") + part[part.rfind(".") :]) for part in k): v
243+
tuple(
244+
(
245+
part[: part.rfind(".")].replace(".", "_")
246+
+ part[part.rfind(".") :]
247+
)
248+
for part in k
249+
): v
200250
for k, v in results.items()
201251
}
202252
)
203253

204-
205254
base.Parser.parse = _parse
206255
base.Parser.__alias_shadowed_imports = __alias_shadowed_imports
207-
208256
logger.info("Patched Parser.parse method.")
209257

210-
patch_parse()
211-
212258

259+
patch_parse()

here/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# generated by fastapi-codegen:
2+
# filename: shadowed_imports.yaml
3+
# timestamp: 2020-06-19T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from fastapi import FastAPI
8+
9+
app = FastAPI(
10+
title='REST API',
11+
version='0.0.1',
12+
servers=[{'url': 'https://api.something.com/1'}],
13+
)

here/models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# generated by fastapi-codegen:
2+
# filename: shadowed_imports.yaml
3+
# timestamp: 2020-06-19T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from datetime import date as date_aliased
8+
from typing import Optional
9+
10+
from pydantic import BaseModel, Field
11+
12+
13+
class MarketingOptIn(BaseModel):
14+
optedIn: Optional[bool] = Field(None, example=False)
15+
date: Optional[date_aliased] = Field(None, example='2018-04-26T17:03:25.155Z')

0 commit comments

Comments
 (0)