|
1 | | -# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors |
2 | | -# |
3 | | -# SPDX-License-Identifier: Apache-2.0 |
4 | | -import logging |
5 | | -from itertools import groupby |
6 | | -from pathlib import Path |
7 | | -from typing import NamedTuple |
8 | | - |
9 | | -from datamodel_code_generator.format import CodeFormatter |
10 | | -from datamodel_code_generator.imports import IMPORT_ANNOTATIONS, Import, Imports |
11 | | -from datamodel_code_generator.model.base import DataModel |
12 | | -from datamodel_code_generator.parser import base |
13 | | -from datamodel_code_generator.reference import ModelResolver |
14 | | - |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Save the original method before patching: this reference is taken when the
# module is executed, so the implementation that was in place at that moment
# stays reachable after patch_parse() rebinds base.Parser.parse.
original_parse = base.Parser.parse
19 | | - |
20 | | - |
def patch_parse() -> None:  # noqa: C901
    """Monkey-patch ``base.Parser`` with a customised ``parse`` implementation.

    Two attributes are assigned on ``base.Parser``:

    * ``parse`` is rebound to the local ``_parse`` function.
    * ``__alias_shadowed_imports`` is added as a new helper that ``_parse``
      calls before the other per-module post-processing steps.

    An info-level log record is emitted once both assignments are done.
    """

    def __alias_shadowed_imports(
        self: base.Parser,
        models: list[DataModel],
        all_model_field_names: set[str],
    ) -> None:
        """Rename field types whose name collides with a field name.

        For every field of every model, if the field's ``data_type.type`` is
        contained in ``all_model_field_names``, the type is rewritten in place
        to ``<type>_aliased``. When the data type carries an import, that
        import's ``alias`` is set to the same new name.

        Args:
            self: The parser instance (not read by this helper).
            models: Models whose fields are inspected and mutated in place.
            all_model_field_names: Field names to check type names against.
        """
        for model in models:
            for model_field in model.fields:
                if model_field.data_type.type in all_model_field_names:
                    alias = model_field.data_type.type + "_aliased"
                    model_field.data_type.type = alias
                    if model_field.data_type.import_:
                        model_field.data_type.import_.alias = alias

    def _parse(  # noqa: PLR0912, PLR0914, PLR0915
        self: base.Parser,
        with_import: bool | None = True,  # noqa: FBT001, FBT002
        format_: bool | None = True,  # noqa: FBT001, FBT002
        settings_path: Path | None = None,
    ) -> str | dict[tuple[str, ...], base.Result]:
        """Parse the input and render the generated code per module.

        NOTE(review): this appears to mirror the library's own ``Parser.parse``
        with an extra ``__alias_shadowed_imports`` step -- confirm against the
        installed datamodel-code-generator version when upgrading.

        Because this function is defined outside a class body, Python applies
        no private-name mangling here. The parser's double-underscore methods
        are therefore reached through explicit ``_Parser__...`` names, while
        ``self.__alias_shadowed_imports`` resolves to the attribute literally
        named ``__alias_shadowed_imports`` that ``patch_parse`` assigns.

        Args:
            with_import: When truthy, ``IMPORT_ANNOTATIONS`` is appended to
                ``self.imports`` and the import sections are emitted ahead of
                each module's model code.
            format_: When truthy, a ``CodeFormatter`` is built and applied to
                every rendered module body.
            settings_path: Forwarded to ``CodeFormatter``.

        Returns:
            The body string of the ``("__init__.py",)`` result when that is
            the only key produced; otherwise a dict mapping module path
            tuples to ``base.Result`` objects.
        """
        self.parse_raw()

        if with_import:
            self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: CodeFormatter | None = CodeFormatter(
                self.target_python_version,
                settings_path,
                self.wrap_string_literal,
                skip_string_normalization=not self.use_double_quotes,
                known_third_party=self.known_third_party,
                custom_formatters=self.custom_formatter,
                custom_formatters_kwargs=self.custom_formatters_kwargs,
                encoding=self.encoding,
                formatters=self.formatters,
            )
        else:
            code_formatter = None

        _, sorted_data_models, require_update_action_models = base.sort_data_models(
            self.results
        )

        # Output accumulator: file path tuple -> rendered result.
        results: dict[tuple[str, ...], base.Result] = {}

        def module_key(data_model: DataModel) -> tuple[str, ...]:
            # Grouping key: the model's module path as a tuple.
            return tuple(data_model.module_path)

        def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
            # Sort by module depth first, then by the path itself. groupby
            # only merges adjacent items, so equal paths must be contiguous.
            return (len(data_model.module_path), tuple(data_model.module_path))

        # process in reverse order to correctly establish module levels
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=sort_key, reverse=True),
            key=module_key,
        )

        module_models: list[tuple[tuple[str, ...], list[DataModel]]] = []
        # Passed to __collapse_root_models below and consumed afterwards to
        # drop those models (and their imports) from their modules.
        unused_models: list[DataModel] = []
        model_to_module_models: dict[
            DataModel, tuple[tuple[str, ...], list[DataModel]]
        ] = {}
        module_to_import: dict[tuple[str, ...], Imports] = {}

        previous_module: tuple[str, ...] = ()
        for module, models in ((k, [*v]) for k, v in grouped_models):
            # Record the owning (module, models) pair for every model before
            # the list is passed to the duplicate-handling helpers.
            for model in models:
                model_to_module_models[model] = module, models
            self._Parser__delete_duplicate_models(models)
            self._Parser__replace_duplicate_name_in_module(models)
            # When the previous module is more than one level deeper than the
            # current one, add empty entries for the intermediate prefixes of
            # the previous module path (longest prefix first).
            if len(previous_module) - len(module) > 1:
                module_models.extend(
                    (
                        previous_module[:parts],
                        [],
                    )
                    for parts in range(len(previous_module) - 1, len(module), -1)
                )
            module_models.append(
                (
                    module,
                    models,
                )
            )
            previous_module = module

        class Processed(NamedTuple):
            # Per-module state carried between the processing passes below.
            module: tuple[str, ...]  # output file path tuple (ends in ".py")
            models: list[DataModel]  # models rendered into that file
            init: bool  # True when the output file is a package __init__.py
            imports: Imports  # imports collected for this module
            scoped_model_resolver: ModelResolver  # resolver built per module

        processed_models: list[Processed] = []

        for module_, models in module_models:
            imports = module_to_import[module_] = Imports(self.use_exact_imports)
            init = False
            if module_:
                # Make sure the parent package has an __init__.py entry.
                parent = (*module_[:-1], "__init__.py")
                if parent not in results:
                    results[parent] = base.Result(body="")
                if (*module_, "__init__.py") in results:
                    # This module path was already registered as a package,
                    # so its models are written to that package's __init__.py.
                    module = (*module_, "__init__.py")
                    init = True
                else:
                    # Plain module file; hyphens in any path part become
                    # underscores.
                    module = tuple(
                        part.replace("-", "_")
                        for part in (*module_[:-1], f"{module_[-1]}.py")
                    )
            else:
                # Empty module path: everything goes to the root __init__.py.
                module = ("__init__.py",)

            # Every non-None field name used by any model of this module.
            all_module_fields = {
                field.name
                for model in models
                for field in model.fields
                if field.name is not None
            }
            scoped_model_resolver = ModelResolver(exclude_names=all_module_fields)

            # Added step: alias types that collide with field names. It runs
            # before the parser's own private post-processing passes, which
            # are invoked through their explicit `_Parser__...` names.
            self.__alias_shadowed_imports(models, all_module_fields)
            self._Parser__override_required_field(models)
            self._Parser__replace_unique_list_to_set(models)
            self._Parser__change_from_import(
                models, imports, scoped_model_resolver, init
            )
            self._Parser__extract_inherited_enum(models)
            self._Parser__set_reference_default_value_to_field(models)
            self._Parser__reuse_model(models, require_update_action_models)
            self._Parser__collapse_root_models(
                models, unused_models, imports, scoped_model_resolver
            )
            self._Parser__set_default_enum_member(models)
            self._Parser__sort_models(models, imports)
            self._Parser__change_field_name(models)
            self._Parser__apply_discriminator_type(models, imports)
            self._Parser__set_one_literal_on_default(models)

            processed_models.append(
                Processed(module, models, init, imports, scoped_model_resolver)
            )

        # Merge each model's own imports into its module's import collection.
        for processed_model in processed_models:
            for model in processed_model.models:
                processed_model.imports.append(model.imports)

        # Drop models collected in unused_models, together with their imports,
        # from the module they were originally grouped into.
        for unused_model in unused_models:
            module, models = model_to_module_models[unused_model]
            if unused_model in models:  # pragma: no cover
                imports = module_to_import[module]
                imports.remove(unused_model.imports)
                models.remove(unused_model)

        for processed_model in processed_models:
            # postprocess imports to remove unused imports.
            # "Unused" is decided by a plain substring test of the imported
            # name against the rendered model code. The candidates are
            # collected into a list first so the import collection is not
            # mutated while it is being iterated.
            model_code = str("\n".join([str(m) for m in processed_model.models]))
            unused_imports = [
                (from_, import_)
                for from_, imports_ in processed_model.imports.items()
                for import_ in imports_
                if import_ not in model_code
            ]
            for from_, import_ in unused_imports:
                processed_model.imports.remove(Import(from_=from_, import_=import_))

        for (
            module,
            models,
            init,
            imports,
            scoped_model_resolver,
        ) in processed_models:  # noqa: B007
            # process after removing unused models
            self._Parser__change_imported_model_name(
                models, imports, scoped_model_resolver
            )

        # Render every processed module into its final text body.
        for (
            module,
            models,
            init,
            imports,
            scoped_model_resolver,
        ) in processed_models:  # noqa: B007
            result: list[str] = []
            if models:
                if with_import:
                    result += [str(self.imports), str(imports), "\n"]

                code = base.dump_templates(models)
                result += [code]

                if self.dump_resolve_reference_action is not None:
                    result += [
                        "\n",
                        self.dump_resolve_reference_action(
                            m.reference.short_name
                            for m in models
                            if m.path in require_update_action_models
                        ),
                    ]
            # Modules with nothing to render are skipped unless they are a
            # package __init__.py, which is still written (with an empty body).
            if not result and not init:
                continue
            body = "\n".join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            results[module] = base.Result(
                body=body, source=models[0].file_path if models else None
            )

        # retain existing behaviour
        # A lone root __init__.py is returned as its body string, not a dict.
        if [*results] == [("__init__.py",)]:
            return results["__init__.py",].body

        # Normalise hyphens to underscores in every part of every result key.
        results = {tuple(i.replace("-", "_") for i in k): v for k, v in results.items()}
        # With treat_dot_as_module set, the parser's own post-processing
        # decides the final layout. Otherwise, in each path part, the dots
        # before the last one are replaced by underscores and the text from
        # the last dot onwards is kept as is.
        return (
            self._Parser__postprocess_result_modules(results)
            if self.treat_dot_as_module
            else {
                tuple(
                    (
                        part[: part.rfind(".")].replace(".", "_")
                        + part[part.rfind(".") :]
                    )
                    for part in k
                ): v
                for k, v in results.items()
            }
        )

    # Install the replacement method and the helper it relies on. Assigned
    # from function scope (not a class body), the helper keeps the literal
    # attribute name `__alias_shadowed_imports`.
    base.Parser.parse = _parse
    base.Parser.__alias_shadowed_imports = __alias_shadowed_imports
    logger.info("Patched Parser.parse method.")
257 | | - |
| 1 | +from .patches import patch_parse |
258 | 2 |
|
# Apply the parser patch as soon as this module is executed.
patch_parse()
0 commit comments