Skip to content

Commit 763098c

Browse files
committed
Patch parse method to fix shadowed imports issue with aliasing
1 parent a54038c commit 763098c

File tree

4 files changed

+259
-0
lines changed

4 files changed

+259
-0
lines changed

fastapi_code_generator/__init__.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import logging
5+
from pathlib import Path
6+
from typing import Union
7+
8+
logger = logging.getLogger(__name__)
9+
10+
from datamodel_code_generator.parser import base
11+
from datamodel_code_generator.model.base import DataModel
12+
13+
# Save the original method before patching
14+
original_parse = base.Parser.parse
15+
16+
def patch_parse() -> None: # noqa: C901
17+
def __alias_shadowed_imports(
18+
self,
19+
models: list[DataModel],
20+
all_model_field_names: set[str],
21+
) -> None:
22+
for model in models:
23+
for model_field in model.fields:
24+
if model_field.data_type.type in all_model_field_names:
25+
alias = model_field.data_type.type + "_aliased"
26+
model_field.data_type.type = alias
27+
model_field.data_type.import_.alias = alias
28+
29+
def _parse( # noqa: PLR0912, PLR0914, PLR0915
30+
self,
31+
with_import: bool | None = True, # noqa: FBT001, FBT002
32+
format_: bool | None = True, # noqa: FBT001, FBT002
33+
settings_path: Path | None = None,
34+
) -> str | dict[tuple[str, ...], base.Result]:
35+
self.parse_raw()
36+
37+
if with_import:
38+
self.imports.append(base.IMPORT_ANNOTATIONS)
39+
40+
if format_:
41+
code_formatter: base.CodeFormatter | None = base.CodeFormatter(
42+
self.target_python_version,
43+
settings_path,
44+
self.wrap_string_literal,
45+
skip_string_normalization=not self.use_double_quotes,
46+
known_third_party=self.known_third_party,
47+
custom_formatters=self.custom_formatter,
48+
custom_formatters_kwargs=self.custom_formatters_kwargs,
49+
encoding=self.encoding,
50+
formatters=self.formatters,
51+
)
52+
else:
53+
code_formatter = None
54+
55+
_, sorted_data_models, require_update_action_models = base.sort_data_models(self.results)
56+
57+
results: dict[tuple[str, ...], base.Result] = {}
58+
59+
def module_key(data_model: DataModel) -> tuple[str, ...]:
60+
return tuple(data_model.module_path)
61+
62+
def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
63+
return (len(data_model.module_path), tuple(data_model.module_path))
64+
65+
# process in reverse order to correctly establish module levels
66+
grouped_models = base.groupby(
67+
sorted(sorted_data_models.values(), key=sort_key, reverse=True),
68+
key=module_key,
69+
)
70+
71+
module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
72+
unused_models: list[DataModel] = []
73+
model_to_module_models: dict[DataModel, tuple[tuple[str, ...], list[DataModel]]] = {}
74+
module_to_import: dict[tuple[str, ...], base.Imports] = {}
75+
76+
previous_module: tuple[str, ...] = ()
77+
for module, models in ((k, [*v]) for k, v in grouped_models):
78+
for model in models:
79+
model_to_module_models[model] = module, models
80+
self._Parser__delete_duplicate_models(models)
81+
self._Parser__replace_duplicate_name_in_module(models)
82+
if len(previous_module) - len(module) > 1:
83+
module_models.extend(
84+
(
85+
previous_module[:parts],
86+
[],
87+
)
88+
for parts in range(len(previous_module) - 1, len(module), -1)
89+
)
90+
module_models.append((
91+
module,
92+
models,
93+
))
94+
previous_module = module
95+
96+
class Processed(base.NamedTuple):
97+
module: tuple[str, ...]
98+
models: list[DataModel]
99+
init: bool
100+
imports: base.Imports
101+
scoped_model_resolver: base.ModelResolver
102+
103+
processed_models: list[Processed] = []
104+
105+
for module_, models in module_models:
106+
imports = module_to_import[module_] = base.Imports(self.use_exact_imports)
107+
init = False
108+
if module_:
109+
parent = (*module_[:-1], "__init__.py")
110+
if parent not in results:
111+
results[parent] = base.Result(body="")
112+
if (*module_, "__init__.py") in results:
113+
module = (*module_, "__init__.py")
114+
init = True
115+
else:
116+
module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py"))
117+
else:
118+
module = ("__init__.py",)
119+
120+
all_module_fields = {field.name for model in models for field in model.fields if field.name is not None}
121+
scoped_model_resolver = base.ModelResolver(exclude_names=all_module_fields)
122+
123+
self.__alias_shadowed_imports(models, all_module_fields)
124+
self._Parser__override_required_field(models)
125+
self._Parser__replace_unique_list_to_set(models)
126+
self._Parser__change_from_import(models, imports, scoped_model_resolver, init)
127+
self._Parser__extract_inherited_enum(models)
128+
self._Parser__set_reference_default_value_to_field(models)
129+
self._Parser__reuse_model(models, require_update_action_models)
130+
self._Parser__collapse_root_models(models, unused_models, imports, scoped_model_resolver)
131+
self._Parser__set_default_enum_member(models)
132+
self._Parser__sort_models(models, imports)
133+
self._Parser__change_field_name(models)
134+
self._Parser__apply_discriminator_type(models, imports)
135+
self._Parser__set_one_literal_on_default(models)
136+
137+
processed_models.append(Processed(module, models, init, imports, scoped_model_resolver))
138+
139+
for processed_model in processed_models:
140+
for model in processed_model.models:
141+
processed_model.imports.append(model.imports)
142+
143+
for unused_model in unused_models:
144+
module, models = model_to_module_models[unused_model]
145+
if unused_model in models: # pragma: no cover
146+
imports = module_to_import[module]
147+
imports.remove(unused_model.imports)
148+
models.remove(unused_model)
149+
150+
for processed_model in processed_models:
151+
# postprocess imports to remove unused imports.
152+
model_code = str("\n".join([str(m) for m in processed_model.models]))
153+
unused_imports = [
154+
(from_, import_)
155+
for from_, imports_ in processed_model.imports.items()
156+
for import_ in imports_
157+
if import_ not in model_code
158+
]
159+
for from_, import_ in unused_imports:
160+
processed_model.imports.remove(Import(from_=from_, import_=import_))
161+
162+
for module, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007
163+
# process after removing unused models
164+
self._Parser__change_imported_model_name(models, imports, scoped_model_resolver)
165+
166+
for module, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007
167+
result: list[str] = []
168+
if models:
169+
if with_import:
170+
result += [str(self.imports), str(imports), "\n"]
171+
172+
code = base.dump_templates(models)
173+
result += [code]
174+
175+
if self.dump_resolve_reference_action is not None:
176+
result += [
177+
"\n",
178+
self.dump_resolve_reference_action(
179+
m.reference.short_name for m in models if m.path in require_update_action_models
180+
),
181+
]
182+
if not result and not init:
183+
continue
184+
body = "\n".join(result)
185+
if code_formatter:
186+
body = code_formatter.format_code(body)
187+
188+
results[module] = base.Result(body=body, source=models[0].file_path if models else None)
189+
190+
# retain existing behaviour
191+
if [*results] == [("__init__.py",)]:
192+
return results["__init__.py",].body
193+
194+
results = {tuple(i.replace("-", "_") for i in k): v for k, v in results.items()}
195+
return (
196+
self._Parser__postprocess_result_modules(results)
197+
if self.treat_dot_as_module
198+
else {
199+
tuple((part[: part.rfind(".")].replace(".", "_") + part[part.rfind(".") :]) for part in k): v
200+
for k, v in results.items()
201+
}
202+
)
203+
204+
205+
base.Parser.parse = _parse
206+
base.Parser.__alias_shadowed_imports = __alias_shadowed_imports
207+
208+
logger.info("Patched Parser.parse method.")
209+
210+
patch_parse()
211+
212+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# generated by fastapi-codegen:
2+
# filename: shadowed_imports.yaml
3+
# timestamp: 2020-06-19T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from fastapi import FastAPI
8+
9+
app = FastAPI(
10+
title='REST API',
11+
version='0.0.1',
12+
servers=[{'url': 'https://api.something.com/1'}],
13+
)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# generated by fastapi-codegen:
2+
# filename: shadowed_imports.yaml
3+
# timestamp: 2020-06-19T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from datetime import date as date_aliased
8+
from typing import Optional
9+
10+
from pydantic import BaseModel, Field
11+
12+
13+
class MarketingOptIn(BaseModel):
14+
optedIn: Optional[bool] = Field(None, example=False)
15+
date: Optional[date_aliased] = Field(None, example='2018-04-26T17:03:25.155Z')
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
openapi: 3.0.0
2+
info:
3+
title: REST API
4+
version: 0.0.1
5+
servers:
6+
- url: https://api.something.com/1
7+
components:
8+
schemas:
9+
marketingOptIn:
10+
type: object
11+
properties:
12+
optedIn:
13+
type: boolean
14+
example: false
15+
date:
16+
type: string
17+
format: date
18+
example: '2018-04-26T17:03:25.155Z'
19+
paths: {}

0 commit comments

Comments
 (0)