|
| 1 | +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +import logging |
| 5 | +from pathlib import Path |
| 6 | +from typing import Union |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | +from datamodel_code_generator.parser import base |
| 11 | +from datamodel_code_generator.model.base import DataModel |
| 12 | + |
| 13 | +# Save the original method before patching |
| 14 | +original_parse = base.Parser.parse |
| 15 | + |
| 16 | +def patch_parse() -> None: # noqa: C901 |
| 17 | + def __alias_shadowed_imports( |
| 18 | + self, |
| 19 | + models: list[DataModel], |
| 20 | + all_model_field_names: set[str], |
| 21 | + ) -> None: |
| 22 | + for model in models: |
| 23 | + for model_field in model.fields: |
| 24 | + if model_field.data_type.type in all_model_field_names: |
| 25 | + alias = model_field.data_type.type + "_aliased" |
| 26 | + model_field.data_type.type = alias |
| 27 | + model_field.data_type.import_.alias = alias |
| 28 | + |
| 29 | + def _parse( # noqa: PLR0912, PLR0914, PLR0915 |
| 30 | + self, |
| 31 | + with_import: bool | None = True, # noqa: FBT001, FBT002 |
| 32 | + format_: bool | None = True, # noqa: FBT001, FBT002 |
| 33 | + settings_path: Path | None = None, |
| 34 | + ) -> str | dict[tuple[str, ...], base.Result]: |
| 35 | + self.parse_raw() |
| 36 | + |
| 37 | + if with_import: |
| 38 | + self.imports.append(base.IMPORT_ANNOTATIONS) |
| 39 | + |
| 40 | + if format_: |
| 41 | + code_formatter: base.CodeFormatter | None = base.CodeFormatter( |
| 42 | + self.target_python_version, |
| 43 | + settings_path, |
| 44 | + self.wrap_string_literal, |
| 45 | + skip_string_normalization=not self.use_double_quotes, |
| 46 | + known_third_party=self.known_third_party, |
| 47 | + custom_formatters=self.custom_formatter, |
| 48 | + custom_formatters_kwargs=self.custom_formatters_kwargs, |
| 49 | + encoding=self.encoding, |
| 50 | + formatters=self.formatters, |
| 51 | + ) |
| 52 | + else: |
| 53 | + code_formatter = None |
| 54 | + |
| 55 | + _, sorted_data_models, require_update_action_models = base.sort_data_models(self.results) |
| 56 | + |
| 57 | + results: dict[tuple[str, ...], base.Result] = {} |
| 58 | + |
| 59 | + def module_key(data_model: DataModel) -> tuple[str, ...]: |
| 60 | + return tuple(data_model.module_path) |
| 61 | + |
| 62 | + def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]: |
| 63 | + return (len(data_model.module_path), tuple(data_model.module_path)) |
| 64 | + |
| 65 | + # process in reverse order to correctly establish module levels |
| 66 | + grouped_models = base.groupby( |
| 67 | + sorted(sorted_data_models.values(), key=sort_key, reverse=True), |
| 68 | + key=module_key, |
| 69 | + ) |
| 70 | + |
| 71 | + module_models: list[tuple[tuple[str, ...], list[DataModel]]] = [] |
| 72 | + unused_models: list[DataModel] = [] |
| 73 | + model_to_module_models: dict[DataModel, tuple[tuple[str, ...], list[DataModel]]] = {} |
| 74 | + module_to_import: dict[tuple[str, ...], base.Imports] = {} |
| 75 | + |
| 76 | + previous_module: tuple[str, ...] = () |
| 77 | + for module, models in ((k, [*v]) for k, v in grouped_models): |
| 78 | + for model in models: |
| 79 | + model_to_module_models[model] = module, models |
| 80 | + self._Parser__delete_duplicate_models(models) |
| 81 | + self._Parser__replace_duplicate_name_in_module(models) |
| 82 | + if len(previous_module) - len(module) > 1: |
| 83 | + module_models.extend( |
| 84 | + ( |
| 85 | + previous_module[:parts], |
| 86 | + [], |
| 87 | + ) |
| 88 | + for parts in range(len(previous_module) - 1, len(module), -1) |
| 89 | + ) |
| 90 | + module_models.append(( |
| 91 | + module, |
| 92 | + models, |
| 93 | + )) |
| 94 | + previous_module = module |
| 95 | + |
| 96 | + class Processed(base.NamedTuple): |
| 97 | + module: tuple[str, ...] |
| 98 | + models: list[DataModel] |
| 99 | + init: bool |
| 100 | + imports: base.Imports |
| 101 | + scoped_model_resolver: base.ModelResolver |
| 102 | + |
| 103 | + processed_models: list[Processed] = [] |
| 104 | + |
| 105 | + for module_, models in module_models: |
| 106 | + imports = module_to_import[module_] = base.Imports(self.use_exact_imports) |
| 107 | + init = False |
| 108 | + if module_: |
| 109 | + parent = (*module_[:-1], "__init__.py") |
| 110 | + if parent not in results: |
| 111 | + results[parent] = base.Result(body="") |
| 112 | + if (*module_, "__init__.py") in results: |
| 113 | + module = (*module_, "__init__.py") |
| 114 | + init = True |
| 115 | + else: |
| 116 | + module = tuple(part.replace("-", "_") for part in (*module_[:-1], f"{module_[-1]}.py")) |
| 117 | + else: |
| 118 | + module = ("__init__.py",) |
| 119 | + |
| 120 | + all_module_fields = {field.name for model in models for field in model.fields if field.name is not None} |
| 121 | + scoped_model_resolver = base.ModelResolver(exclude_names=all_module_fields) |
| 122 | + |
| 123 | + self.__alias_shadowed_imports(models, all_module_fields) |
| 124 | + self._Parser__override_required_field(models) |
| 125 | + self._Parser__replace_unique_list_to_set(models) |
| 126 | + self._Parser__change_from_import(models, imports, scoped_model_resolver, init) |
| 127 | + self._Parser__extract_inherited_enum(models) |
| 128 | + self._Parser__set_reference_default_value_to_field(models) |
| 129 | + self._Parser__reuse_model(models, require_update_action_models) |
| 130 | + self._Parser__collapse_root_models(models, unused_models, imports, scoped_model_resolver) |
| 131 | + self._Parser__set_default_enum_member(models) |
| 132 | + self._Parser__sort_models(models, imports) |
| 133 | + self._Parser__change_field_name(models) |
| 134 | + self._Parser__apply_discriminator_type(models, imports) |
| 135 | + self._Parser__set_one_literal_on_default(models) |
| 136 | + |
| 137 | + processed_models.append(Processed(module, models, init, imports, scoped_model_resolver)) |
| 138 | + |
| 139 | + for processed_model in processed_models: |
| 140 | + for model in processed_model.models: |
| 141 | + processed_model.imports.append(model.imports) |
| 142 | + |
| 143 | + for unused_model in unused_models: |
| 144 | + module, models = model_to_module_models[unused_model] |
| 145 | + if unused_model in models: # pragma: no cover |
| 146 | + imports = module_to_import[module] |
| 147 | + imports.remove(unused_model.imports) |
| 148 | + models.remove(unused_model) |
| 149 | + |
| 150 | + for processed_model in processed_models: |
| 151 | + # postprocess imports to remove unused imports. |
| 152 | + model_code = str("\n".join([str(m) for m in processed_model.models])) |
| 153 | + unused_imports = [ |
| 154 | + (from_, import_) |
| 155 | + for from_, imports_ in processed_model.imports.items() |
| 156 | + for import_ in imports_ |
| 157 | + if import_ not in model_code |
| 158 | + ] |
| 159 | + for from_, import_ in unused_imports: |
| 160 | + processed_model.imports.remove(Import(from_=from_, import_=import_)) |
| 161 | + |
| 162 | + for module, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007 |
| 163 | + # process after removing unused models |
| 164 | + self._Parser__change_imported_model_name(models, imports, scoped_model_resolver) |
| 165 | + |
| 166 | + for module, models, init, imports, scoped_model_resolver in processed_models: # noqa: B007 |
| 167 | + result: list[str] = [] |
| 168 | + if models: |
| 169 | + if with_import: |
| 170 | + result += [str(self.imports), str(imports), "\n"] |
| 171 | + |
| 172 | + code = base.dump_templates(models) |
| 173 | + result += [code] |
| 174 | + |
| 175 | + if self.dump_resolve_reference_action is not None: |
| 176 | + result += [ |
| 177 | + "\n", |
| 178 | + self.dump_resolve_reference_action( |
| 179 | + m.reference.short_name for m in models if m.path in require_update_action_models |
| 180 | + ), |
| 181 | + ] |
| 182 | + if not result and not init: |
| 183 | + continue |
| 184 | + body = "\n".join(result) |
| 185 | + if code_formatter: |
| 186 | + body = code_formatter.format_code(body) |
| 187 | + |
| 188 | + results[module] = base.Result(body=body, source=models[0].file_path if models else None) |
| 189 | + |
| 190 | + # retain existing behaviour |
| 191 | + if [*results] == [("__init__.py",)]: |
| 192 | + return results["__init__.py",].body |
| 193 | + |
| 194 | + results = {tuple(i.replace("-", "_") for i in k): v for k, v in results.items()} |
| 195 | + return ( |
| 196 | + self._Parser__postprocess_result_modules(results) |
| 197 | + if self.treat_dot_as_module |
| 198 | + else { |
| 199 | + tuple((part[: part.rfind(".")].replace(".", "_") + part[part.rfind(".") :]) for part in k): v |
| 200 | + for k, v in results.items() |
| 201 | + } |
| 202 | + ) |
| 203 | + |
| 204 | + |
| 205 | + base.Parser.parse = _parse |
| 206 | + base.Parser.__alias_shadowed_imports = __alias_shadowed_imports |
| 207 | + |
| 208 | + logger.info("Patched Parser.parse method.") |
| 209 | + |
| 210 | +patch_parse() |
| 211 | + |
| 212 | + |
0 commit comments