Skip to content

Commit f1cc623

Browse files
committed
Refactor __init__ and patch_parse function
1 parent 342bf37 commit f1cc623

File tree

2 files changed

+254
-257
lines changed

2 files changed

+254
-257
lines changed

fastapi_code_generator/__init__.py

Lines changed: 1 addition & 257 deletions
Original file line numberDiff line numberDiff line change
@@ -1,259 +1,3 @@
1-
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2-
#
3-
# SPDX-License-Identifier: Apache-2.0
4-
import logging
5-
from itertools import groupby
6-
from pathlib import Path
7-
from typing import NamedTuple
8-
9-
from datamodel_code_generator.format import CodeFormatter
10-
from datamodel_code_generator.imports import IMPORT_ANNOTATIONS, Import, Imports
11-
from datamodel_code_generator.model.base import DataModel
12-
from datamodel_code_generator.parser import base
13-
from datamodel_code_generator.reference import ModelResolver
14-
15-
logger = logging.getLogger(__name__)
16-
17-
# Save the original method before patching
18-
original_parse = base.Parser.parse
19-
20-
21-
def patch_parse() -> None: # noqa: C901
22-
def __alias_shadowed_imports(
23-
self: base.Parser,
24-
models: list[DataModel],
25-
all_model_field_names: set[str],
26-
) -> None:
27-
for model in models:
28-
for model_field in model.fields:
29-
if model_field.data_type.type in all_model_field_names:
30-
alias = model_field.data_type.type + "_aliased"
31-
model_field.data_type.type = alias
32-
if model_field.data_type.import_:
33-
model_field.data_type.import_.alias = alias
34-
35-
def _parse( # noqa: PLR0912, PLR0914, PLR0915
36-
self: base.Parser,
37-
with_import: bool | None = True, # noqa: FBT001, FBT002
38-
format_: bool | None = True, # noqa: FBT001, FBT002
39-
settings_path: Path | None = None,
40-
) -> str | dict[tuple[str, ...], base.Result]:
41-
self.parse_raw()
42-
43-
if with_import:
44-
self.imports.append(IMPORT_ANNOTATIONS)
45-
46-
if format_:
47-
code_formatter: CodeFormatter | None = CodeFormatter(
48-
self.target_python_version,
49-
settings_path,
50-
self.wrap_string_literal,
51-
skip_string_normalization=not self.use_double_quotes,
52-
known_third_party=self.known_third_party,
53-
custom_formatters=self.custom_formatter,
54-
custom_formatters_kwargs=self.custom_formatters_kwargs,
55-
encoding=self.encoding,
56-
formatters=self.formatters,
57-
)
58-
else:
59-
code_formatter = None
60-
61-
_, sorted_data_models, require_update_action_models = base.sort_data_models(
62-
self.results
63-
)
64-
65-
results: dict[tuple[str, ...], base.Result] = {}
66-
67-
def module_key(data_model: DataModel) -> tuple[str, ...]:
68-
return tuple(data_model.module_path)
69-
70-
def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
71-
return (len(data_model.module_path), tuple(data_model.module_path))
72-
73-
# process in reverse order to correctly establish module levels
74-
grouped_models = groupby(
75-
sorted(sorted_data_models.values(), key=sort_key, reverse=True),
76-
key=module_key,
77-
)
78-
79-
module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
80-
unused_models: list[DataModel] = []
81-
model_to_module_models: dict[
82-
DataModel, tuple[tuple[str, ...], list[DataModel]]
83-
] = {}
84-
module_to_import: dict[tuple[str, ...], Imports] = {}
85-
86-
previous_module: tuple[str, ...] = ()
87-
for module, models in ((k, [*v]) for k, v in grouped_models):
88-
for model in models:
89-
model_to_module_models[model] = module, models
90-
self._Parser__delete_duplicate_models(models)
91-
self._Parser__replace_duplicate_name_in_module(models)
92-
if len(previous_module) - len(module) > 1:
93-
module_models.extend(
94-
(
95-
previous_module[:parts],
96-
[],
97-
)
98-
for parts in range(len(previous_module) - 1, len(module), -1)
99-
)
100-
module_models.append(
101-
(
102-
module,
103-
models,
104-
)
105-
)
106-
previous_module = module
107-
108-
class Processed(NamedTuple):
109-
module: tuple[str, ...]
110-
models: list[DataModel]
111-
init: bool
112-
imports: Imports
113-
scoped_model_resolver: ModelResolver
114-
115-
processed_models: list[Processed] = []
116-
117-
for module_, models in module_models:
118-
imports = module_to_import[module_] = Imports(self.use_exact_imports)
119-
init = False
120-
if module_:
121-
parent = (*module_[:-1], "__init__.py")
122-
if parent not in results:
123-
results[parent] = base.Result(body="")
124-
if (*module_, "__init__.py") in results:
125-
module = (*module_, "__init__.py")
126-
init = True
127-
else:
128-
module = tuple(
129-
part.replace("-", "_")
130-
for part in (*module_[:-1], f"{module_[-1]}.py")
131-
)
132-
else:
133-
module = ("__init__.py",)
134-
135-
all_module_fields = {
136-
field.name
137-
for model in models
138-
for field in model.fields
139-
if field.name is not None
140-
}
141-
scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)
142-
143-
self.__alias_shadowed_imports(models, all_module_fields)
144-
self._Parser__override_required_field(models)
145-
self._Parser__replace_unique_list_to_set(models)
146-
self._Parser__change_from_import(
147-
models, imports, scoped_model_resolver, init
148-
)
149-
self._Parser__extract_inherited_enum(models)
150-
self._Parser__set_reference_default_value_to_field(models)
151-
self._Parser__reuse_model(models, require_update_action_models)
152-
self._Parser__collapse_root_models(
153-
models, unused_models, imports, scoped_model_resolver
154-
)
155-
self._Parser__set_default_enum_member(models)
156-
self._Parser__sort_models(models, imports)
157-
self._Parser__change_field_name(models)
158-
self._Parser__apply_discriminator_type(models, imports)
159-
self._Parser__set_one_literal_on_default(models)
160-
161-
processed_models.append(
162-
Processed(module, models, init, imports, scoped_model_resolver)
163-
)
164-
165-
for processed_model in processed_models:
166-
for model in processed_model.models:
167-
processed_model.imports.append(model.imports)
168-
169-
for unused_model in unused_models:
170-
module, models = model_to_module_models[unused_model]
171-
if unused_model in models: # pragma: no cover
172-
imports = module_to_import[module]
173-
imports.remove(unused_model.imports)
174-
models.remove(unused_model)
175-
176-
for processed_model in processed_models:
177-
# postprocess imports to remove unused imports.
178-
model_code = str("\n".join([str(m) for m in processed_model.models]))
179-
unused_imports = [
180-
(from_, import_)
181-
for from_, imports_ in processed_model.imports.items()
182-
for import_ in imports_
183-
if import_ not in model_code
184-
]
185-
for from_, import_ in unused_imports:
186-
processed_model.imports.remove(Import(from_=from_, import_=import_))
187-
188-
for (
189-
module,
190-
models,
191-
init,
192-
imports,
193-
scoped_model_resolver,
194-
) in processed_models: # noqa: B007
195-
# process after removing unused models
196-
self._Parser__change_imported_model_name(
197-
models, imports, scoped_model_resolver
198-
)
199-
200-
for (
201-
module,
202-
models,
203-
init,
204-
imports,
205-
scoped_model_resolver,
206-
) in processed_models: # noqa: B007
207-
result: list[str] = []
208-
if models:
209-
if with_import:
210-
result += [str(self.imports), str(imports), "\n"]
211-
212-
code = base.dump_templates(models)
213-
result += [code]
214-
215-
if self.dump_resolve_reference_action is not None:
216-
result += [
217-
"\n",
218-
self.dump_resolve_reference_action(
219-
m.reference.short_name
220-
for m in models
221-
if m.path in require_update_action_models
222-
),
223-
]
224-
if not result and not init:
225-
continue
226-
body = "\n".join(result)
227-
if code_formatter:
228-
body = code_formatter.format_code(body)
229-
230-
results[module] = base.Result(
231-
body=body, source=models[0].file_path if models else None
232-
)
233-
234-
# retain existing behaviour
235-
if [*results] == [("__init__.py",)]:
236-
return results["__init__.py",].body
237-
238-
results = {tuple(i.replace("-", "_") for i in k): v for k, v in results.items()}
239-
return (
240-
self._Parser__postprocess_result_modules(results)
241-
if self.treat_dot_as_module
242-
else {
243-
tuple(
244-
(
245-
part[: part.rfind(".")].replace(".", "_")
246-
+ part[part.rfind(".") :]
247-
)
248-
for part in k
249-
): v
250-
for k, v in results.items()
251-
}
252-
)
253-
254-
base.Parser.parse = _parse
255-
base.Parser.__alias_shadowed_imports = __alias_shadowed_imports
256-
logger.info("Patched Parser.parse method.")
257-
1+
from .patches import patch_parse
2582

2593
patch_parse()

0 commit comments

Comments
 (0)