22#
33# SPDX-License-Identifier: Apache-2.0
44import logging
5+ from itertools import groupby
56from pathlib import Path
6- from typing import Union
7+ from typing import NamedTuple
78
8- logger = logging .getLogger (__name__ )
9-
10- from datamodel_code_generator .parser import base
9+ from datamodel_code_generator .format import CodeFormatter
10+ from datamodel_code_generator .imports import IMPORT_ANNOTATIONS , Import , Imports
1111from datamodel_code_generator .model .base import DataModel
12+ from datamodel_code_generator .parser import base
13+ from datamodel_code_generator .reference import ModelResolver
14+
15+ logger = logging .getLogger (__name__ )
1216
1317# Save the original method before patching
1418original_parse = base .Parser .parse
1519
20+
1621def patch_parse () -> None : # noqa: C901
1722 def __alias_shadowed_imports (
18- self ,
23+ self : base . Parser ,
1924 models : list [DataModel ],
2025 all_model_field_names : set [str ],
2126 ) -> None :
@@ -24,21 +29,22 @@ def __alias_shadowed_imports(
2429 if model_field .data_type .type in all_model_field_names :
2530 alias = model_field .data_type .type + "_aliased"
2631 model_field .data_type .type = alias
27- model_field .data_type .import_ .alias = alias
32+ if model_field .data_type .import_ :
33+ model_field .data_type .import_ .alias = alias
2834
2935 def _parse ( # noqa: PLR0912, PLR0914, PLR0915
30- self ,
36+ self : base . Parser ,
3137 with_import : bool | None = True , # noqa: FBT001, FBT002
3238 format_ : bool | None = True , # noqa: FBT001, FBT002
3339 settings_path : Path | None = None ,
3440 ) -> str | dict [tuple [str , ...], base .Result ]:
3541 self .parse_raw ()
3642
3743 if with_import :
38- self .imports .append (base . IMPORT_ANNOTATIONS )
44+ self .imports .append (IMPORT_ANNOTATIONS )
3945
4046 if format_ :
41- code_formatter : base . CodeFormatter | None = base . CodeFormatter (
47+ code_formatter : CodeFormatter | None = CodeFormatter (
4248 self .target_python_version ,
4349 settings_path ,
4450 self .wrap_string_literal ,
@@ -52,7 +58,9 @@ def _parse( # noqa: PLR0912, PLR0914, PLR0915
5258 else :
5359 code_formatter = None
5460
55- _ , sorted_data_models , require_update_action_models = base .sort_data_models (self .results )
61+ _ , sorted_data_models , require_update_action_models = base .sort_data_models (
62+ self .results
63+ )
5664
5765 results : dict [tuple [str , ...], base .Result ] = {}
5866
@@ -63,15 +71,17 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
6371 return (len (data_model .module_path ), tuple (data_model .module_path ))
6472
6573 # process in reverse order to correctly establish module levels
66- grouped_models = base . groupby (
74+ grouped_models = groupby (
6775 sorted (sorted_data_models .values (), key = sort_key , reverse = True ),
6876 key = module_key ,
6977 )
7078
7179 module_models : list [tuple [tuple [str , ...], list [DataModel ]]] = []
7280 unused_models : list [DataModel ] = []
73- model_to_module_models : dict [DataModel , tuple [tuple [str , ...], list [DataModel ]]] = {}
74- module_to_import : dict [tuple [str , ...], base .Imports ] = {}
81+ model_to_module_models : dict [
82+ DataModel , tuple [tuple [str , ...], list [DataModel ]]
83+ ] = {}
84+ module_to_import : dict [tuple [str , ...], Imports ] = {}
7585
7686 previous_module : tuple [str , ...] = ()
7787 for module , models in ((k , [* v ]) for k , v in grouped_models ):
@@ -87,23 +97,25 @@ def sort_key(data_model: DataModel) -> tuple[int, tuple[str, ...]]:
8797 )
8898 for parts in range (len (previous_module ) - 1 , len (module ), - 1 )
8999 )
90- module_models .append ((
91- module ,
92- models ,
93- ))
100+ module_models .append (
101+ (
102+ module ,
103+ models ,
104+ )
105+ )
94106 previous_module = module
95107
96- class Processed (base . NamedTuple ):
108+ class Processed (NamedTuple ):
97109 module : tuple [str , ...]
98110 models : list [DataModel ]
99111 init : bool
100- imports : base . Imports
101- scoped_model_resolver : base . ModelResolver
112+ imports : Imports
113+ scoped_model_resolver : ModelResolver
102114
103115 processed_models : list [Processed ] = []
104116
105117 for module_ , models in module_models :
106- imports = module_to_import [module_ ] = base . Imports (self .use_exact_imports )
118+ imports = module_to_import [module_ ] = Imports (self .use_exact_imports )
107119 init = False
108120 if module_ :
109121 parent = (* module_ [:- 1 ], "__init__.py" )
@@ -113,28 +125,42 @@ class Processed(base.NamedTuple):
113125 module = (* module_ , "__init__.py" )
114126 init = True
115127 else :
116- module = tuple (part .replace ("-" , "_" ) for part in (* module_ [:- 1 ], f"{ module_ [- 1 ]} .py" ))
128+ module = tuple (
129+ part .replace ("-" , "_" )
130+ for part in (* module_ [:- 1 ], f"{ module_ [- 1 ]} .py" )
131+ )
117132 else :
118133 module = ("__init__.py" ,)
119134
120- all_module_fields = {field .name for model in models for field in model .fields if field .name is not None }
121- scoped_model_resolver = base .ModelResolver (exclude_names = all_module_fields )
135+ all_module_fields = {
136+ field .name
137+ for model in models
138+ for field in model .fields
139+ if field .name is not None
140+ }
141+ scoped_model_resolver = ModelResolver (exclude_names = all_module_fields )
122142
123143 self .__alias_shadowed_imports (models , all_module_fields )
124144 self ._Parser__override_required_field (models )
125145 self ._Parser__replace_unique_list_to_set (models )
126- self ._Parser__change_from_import (models , imports , scoped_model_resolver , init )
146+ self ._Parser__change_from_import (
147+ models , imports , scoped_model_resolver , init
148+ )
127149 self ._Parser__extract_inherited_enum (models )
128150 self ._Parser__set_reference_default_value_to_field (models )
129151 self ._Parser__reuse_model (models , require_update_action_models )
130- self ._Parser__collapse_root_models (models , unused_models , imports , scoped_model_resolver )
152+ self ._Parser__collapse_root_models (
153+ models , unused_models , imports , scoped_model_resolver
154+ )
131155 self ._Parser__set_default_enum_member (models )
132156 self ._Parser__sort_models (models , imports )
133157 self ._Parser__change_field_name (models )
134158 self ._Parser__apply_discriminator_type (models , imports )
135159 self ._Parser__set_one_literal_on_default (models )
136160
137- processed_models .append (Processed (module , models , init , imports , scoped_model_resolver ))
161+ processed_models .append (
162+ Processed (module , models , init , imports , scoped_model_resolver )
163+ )
138164
139165 for processed_model in processed_models :
140166 for model in processed_model .models :
@@ -159,11 +185,25 @@ class Processed(base.NamedTuple):
159185 for from_ , import_ in unused_imports :
160186 processed_model .imports .remove (Import (from_ = from_ , import_ = import_ ))
161187
162- for module , models , init , imports , scoped_model_resolver in processed_models : # noqa: B007
188+ for (
189+ module ,
190+ models ,
191+ init ,
192+ imports ,
193+ scoped_model_resolver ,
194+ ) in processed_models : # noqa: B007
163195 # process after removing unused models
164- self ._Parser__change_imported_model_name (models , imports , scoped_model_resolver )
196+ self ._Parser__change_imported_model_name (
197+ models , imports , scoped_model_resolver
198+ )
165199
166- for module , models , init , imports , scoped_model_resolver in processed_models : # noqa: B007
200+ for (
201+ module ,
202+ models ,
203+ init ,
204+ imports ,
205+ scoped_model_resolver ,
206+ ) in processed_models : # noqa: B007
167207 result : list [str ] = []
168208 if models :
169209 if with_import :
@@ -176,7 +216,9 @@ class Processed(base.NamedTuple):
176216 result += [
177217 "\n " ,
178218 self .dump_resolve_reference_action (
179- m .reference .short_name for m in models if m .path in require_update_action_models
219+ m .reference .short_name
220+ for m in models
221+ if m .path in require_update_action_models
180222 ),
181223 ]
182224 if not result and not init :
@@ -185,7 +227,9 @@ class Processed(base.NamedTuple):
185227 if code_formatter :
186228 body = code_formatter .format_code (body )
187229
188- results [module ] = base .Result (body = body , source = models [0 ].file_path if models else None )
230+ results [module ] = base .Result (
231+ body = body , source = models [0 ].file_path if models else None
232+ )
189233
190234 # retain existing behaviour
191235 if [* results ] == [("__init__.py" ,)]:
@@ -196,17 +240,20 @@ class Processed(base.NamedTuple):
196240 self ._Parser__postprocess_result_modules (results )
197241 if self .treat_dot_as_module
198242 else {
199- tuple ((part [: part .rfind ("." )].replace ("." , "_" ) + part [part .rfind ("." ) :]) for part in k ): v
243+ tuple (
244+ (
245+ part [: part .rfind ("." )].replace ("." , "_" )
246+ + part [part .rfind ("." ) :]
247+ )
248+ for part in k
249+ ): v
200250 for k , v in results .items ()
201251 }
202252 )
203253
204-
205254 base .Parser .parse = _parse
206255 base .Parser .__alias_shadowed_imports = __alias_shadowed_imports
207-
208256 logger .info ("Patched Parser.parse method." )
209257
210- patch_parse ()
211-
212258
259+ patch_parse ()
0 commit comments