Skip to content

Commit 3369c5e

Browse files
committed
Fix regex_engine config not applied to RootModel generic
1 parent 3e7b16b commit 3369c5e

27 files changed

+863
-31
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,12 @@ lint.per-file-ignores."tests/**/*.py" = [
186186
"INP001", # no implicit namespace
187187
"PLC0415", # local imports in tests are fine
188188
"PLC2701", # private import is fine
189+
"PLR0904", # too many public methods in test classes is fine
189190
"PLR0913", # as many arguments as want
190191
"PLR0915", # can have longer test methods
191192
"PLR0917", # as many arguments as want
192193
"PLR2004", # Magic value used in comparison, consider replacing with a constant variable
194+
"PLR6301", # test methods don't need to use self
193195
"S", # no safety concerns
194196
"SLF001", # can test private methods
195197
]

src/datamodel_code_generator/model/base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,15 @@ def _build_union_type_hint(self) -> str | None:
220220
return f"Union[{', '.join(parts)}]"
221221
return None # pragma: no cover
222222

223+
def _build_base_union_type_hint(self) -> str | None: # pragma: no cover
224+
"""Build Union[] base type hint from data_type.data_types if forward reference requires it."""
225+
if not (self._use_union_operator != self.data_type.use_union_operator and self.data_type.is_union):
226+
return None
227+
parts = [dt.base_type_hint for dt in self.data_type.data_types if dt.base_type_hint]
228+
if len(parts) > 1:
229+
return f"Union[{', '.join(parts)}]"
230+
return None
231+
223232
@property
224233
def type_hint(self) -> str: # noqa: PLR0911
225234
"""Get the type hint string for this field, including nullability."""
@@ -241,6 +250,33 @@ def type_hint(self) -> str: # noqa: PLR0911
241250
return get_optional_type(type_hint, self._use_union_operator)
242251
return type_hint
243252

253+
@property
254+
def base_type_hint(self) -> str:
255+
"""Get the base type hint without constrained type kwargs.
256+
257+
This returns the type without kwargs (e.g., 'str' instead of 'constr(pattern=...)').
258+
Used in RootModel generics when regex_engine config is needed for lookaround patterns.
259+
"""
260+
base_hint = self._build_base_union_type_hint() or self.data_type.base_type_hint
261+
262+
if not base_hint: # pragma: no cover
263+
return NONE
264+
265+
needs_optional = (
266+
(self.nullable is True)
267+
or (self.required and self.type_has_null)
268+
or (self.nullable is None and not self.required and self.fall_back_to_nullable)
269+
)
270+
skip_optional = (
271+
self.has_default_factory
272+
or (self.data_type.is_optional and self.data_type.type != ANY)
273+
or (self.nullable is False)
274+
)
275+
276+
if needs_optional and not skip_optional: # pragma: no cover
277+
return get_optional_type(base_hint, self._use_union_operator)
278+
return base_hint
279+
244280
@property
245281
def imports(self) -> tuple[Import, ...]:
246282
"""Get all imports required for this field's type hint."""

src/datamodel_code_generator/model/pydantic_v2/base_model.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,15 +228,8 @@ def __init__( # noqa: PLR0913
228228
config_parameters["arbitrary_types_allowed"] = True
229229
break
230230

231-
for field in self.fields:
232-
# Check if a regex pattern uses lookarounds.
233-
# Depending on the generation configuration, the pattern may end up in two different places.
234-
pattern = (isinstance(field.constraints, Constraints) and field.constraints.pattern) or (
235-
field.data_type.kwargs or {}
236-
).get("pattern")
237-
if pattern and re.search(r"\(\?<?[=!]", pattern):
238-
config_parameters["regex_engine"] = '"python-re"'
239-
break
231+
if self._has_lookaround_pattern():
232+
config_parameters["regex_engine"] = '"python-re"'
240233

241234
if isinstance(self.extra_template_data.get("config"), dict):
242235
for key, value in self.extra_template_data["config"].items():
@@ -265,3 +258,16 @@ def _get_config_extra(self) -> Literal["'allow'", "'forbid'", "'ignore'"] | None
265258
elif additional_properties is False:
266259
config_extra = "'forbid'"
267260
return config_extra
261+
262+
def _has_lookaround_pattern(self) -> bool:
263+
"""Check if any field has a regex pattern with lookaround assertions."""
264+
lookaround_regex = re.compile(r"\(\?<?[=!]")
265+
for field in self.fields:
266+
pattern = isinstance(field.constraints, Constraints) and field.constraints.pattern
267+
if pattern and lookaround_regex.search(pattern):
268+
return True
269+
for data_type in field.data_type.all_data_types:
270+
pattern = (data_type.kwargs or {}).get("pattern")
271+
if pattern and lookaround_regex.search(pattern):
272+
return True
273+
return False

src/datamodel_code_generator/model/template/pydantic_v2/RootModel.jinja2

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
{%- macro get_type_hint(_fields) -%}
1+
{%- macro get_type_hint(_fields, use_base_type) -%}
22
{%- if _fields -%}
33
{#There will only ever be a single field for RootModel#}
4+
{%- if use_base_type -%}
5+
{{- _fields[0].base_type_hint}}
6+
{%- else -%}
47
{{- _fields[0].type_hint}}
58
{%- endif -%}
9+
{%- endif -%}
610
{%- endmacro -%}
711

812

913
{% for decorator in decorators -%}
1014
{{ decorator }}
1115
{% endfor -%}
1216

13-
class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields)}}]{%- endif -%}):{% if comment is defined %} # {{ comment }}{% endif %}
17+
{#- Use base_type_hint in generic when regex_engine is set to avoid evaluating pattern before config is processed -#}
18+
{%- set use_base_type = config and config.regex_engine -%}
19+
class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields, use_base_type)}}]{%- endif -%}):{% if comment is defined %} # {{ comment }}{% endif %}
1420
{%- if description %}
1521
"""
1622
{{ description | indent(4) }}

src/datamodel_code_generator/types.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,11 @@ def full_name(self) -> str:
422422

423423
@property
424424
def all_data_types(self) -> Iterator[DataType]:
425-
"""Recursively yield all nested DataTypes including self."""
425+
"""Recursively yield all nested DataTypes including self and dict_key."""
426426
for data_type in self.data_types:
427427
yield from data_type.all_data_types
428+
if self.dict_key:
429+
yield from self.dict_key.all_data_types
428430
yield self
429431

430432
def find_source(self, source_type: type[SourceT]) -> SourceT | None:
@@ -618,6 +620,124 @@ def is_union(self) -> bool:
618620
"""Return whether this DataType represents a union of multiple types."""
619621
return len(self.data_types) > 1
620622

623+
# Mapping from constrained type functions to their base Python types.
624+
# Only constr is included because it's the only type with a 'pattern' parameter
625+
# that can trigger lookaround regex detection. Other constrained types (conint,
626+
# confloat, condecimal, conbytes) don't have pattern constraints, so they will
627+
# never need base_type_hint conversion in the regex_engine context.
628+
_CONSTRAINED_TYPE_TO_BASE: ClassVar[dict[str, str]] = {
629+
"constr": "str",
630+
}
631+
632+
@property
633+
def base_type_hint(self) -> str: # noqa: PLR0912, PLR0915
634+
"""Return the base type hint without constrained type kwargs.
635+
636+
For types like constr(pattern=..., min_length=...), this returns just 'str'.
637+
This works recursively for nested types like list[constr(pattern=...)] -> list[str].
638+
639+
This is useful when the pattern contains lookaround assertions that require
640+
regex_engine="python-re", which must be set in model_config. In such cases,
641+
the RootModel generic cannot use the constrained type because it would be
642+
evaluated at class definition time before model_config is processed.
643+
"""
644+
if self.is_func and self.kwargs:
645+
type_: str | None = self.alias or self.type
646+
if type_: # pragma: no branch
647+
base_type = self._CONSTRAINED_TYPE_TO_BASE.get(type_)
648+
if base_type is None:
649+
# Not a constrained type we convert (e.g., conint, confloat)
650+
# Return the full type_hint with kwargs to avoid returning bare function name
651+
return self.type_hint
652+
if self.is_optional and base_type != ANY: # pragma: no cover
653+
return get_optional_type(base_type, self.use_union_operator)
654+
return base_type
655+
656+
type_: str | None = self.alias or self.type
657+
if not type_:
658+
if self.is_tuple: # pragma: no cover
659+
tuple_type = STANDARD_TUPLE if self.use_standard_collections else TUPLE
660+
inner_types = [item.base_type_hint or ANY for item in self.data_types]
661+
type_ = f"{tuple_type}[{', '.join(inner_types)}]" if inner_types else f"{tuple_type}[()]"
662+
elif self.is_union:
663+
data_types: list[str] = []
664+
for data_type in self.data_types:
665+
data_type_type = data_type.base_type_hint
666+
if not data_type_type or data_type_type in data_types: # pragma: no cover
667+
continue
668+
669+
if data_type_type == NONE:
670+
self.is_optional = True
671+
continue
672+
673+
non_optional_data_type_type = _remove_none_from_union(
674+
data_type_type, use_union_operator=self.use_union_operator
675+
)
676+
677+
if non_optional_data_type_type != data_type_type: # pragma: no cover
678+
self.is_optional = True
679+
680+
data_types.append(non_optional_data_type_type)
681+
if not data_types: # pragma: no cover
682+
type_ = ANY
683+
self.import_ = self.import_ or IMPORT_ANY
684+
elif len(data_types) == 1:
685+
type_ = data_types[0]
686+
elif self.use_union_operator:
687+
type_ = UNION_OPERATOR_DELIMITER.join(data_types)
688+
else: # pragma: no cover
689+
type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
690+
elif len(self.data_types) == 1:
691+
type_ = self.data_types[0].base_type_hint
692+
elif self.enum_member_literals: # pragma: no cover
693+
parts = [f"{enum_class}.{member}" for enum_class, member in self.enum_member_literals]
694+
type_ = f"{LITERAL}[{', '.join(parts)}]"
695+
elif self.literals: # pragma: no cover
696+
type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
697+
elif self.reference: # pragma: no cover
698+
type_ = self.reference.short_name
699+
type_ = self._get_wrapped_reference_type_hint(type_)
700+
else: # pragma: no cover
701+
type_ = ""
702+
if self.reference: # pragma: no cover
703+
source = self.reference.source
704+
if isinstance(source, Nullable) and source.nullable:
705+
self.is_optional = True
706+
if self.is_list:
707+
if self.use_generic_container:
708+
list_ = SEQUENCE
709+
elif self.use_standard_collections:
710+
list_ = STANDARD_LIST
711+
else: # pragma: no cover
712+
list_ = LIST
713+
type_ = f"{list_}[{type_}]" if type_ else list_
714+
elif self.is_set: # pragma: no cover
715+
if self.use_generic_container:
716+
set_ = STANDARD_FROZEN_SET if self.use_standard_collections else FROZEN_SET
717+
elif self.use_standard_collections:
718+
set_ = STANDARD_SET
719+
else:
720+
set_ = SET
721+
type_ = f"{set_}[{type_}]" if type_ else set_
722+
elif self.is_dict:
723+
if self.use_generic_container:
724+
dict_ = MAPPING
725+
elif self.use_standard_collections:
726+
dict_ = STANDARD_DICT
727+
else: # pragma: no cover
728+
dict_ = DICT
729+
if self.dict_key or type_:
730+
key = self.dict_key.base_type_hint if self.dict_key else STR
731+
type_ = f"{dict_}[{key}, {type_ or ANY}]"
732+
else: # pragma: no cover
733+
type_ = dict_
734+
735+
if self.is_optional and type_ != ANY:
736+
return get_optional_type(type_, self.use_union_operator)
737+
if self.is_func: # pragma: no cover
738+
return f"{type_}()"
739+
return type_
740+
621741

622742
DataTypeT = TypeVar("DataTypeT", bound=DataType)
623743

tests/cli_doc/test_cli_doc_coverage.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ def collected_options(collection_data: dict[str, Any]) -> set[str]: # pragma: n
4646
class TestCLIDocCoverage: # pragma: no cover
4747
"""Documentation coverage tests."""
4848

49-
def test_all_options_have_cli_doc_markers( # noqa: PLR6301
50-
self, collected_options: set[str]
51-
) -> None:
49+
def test_all_options_have_cli_doc_markers(self, collected_options: set[str]) -> None:
5250
"""Verify that all CLI options (except MANUAL_DOCS) have cli_doc markers."""
5351
all_options = get_all_canonical_options()
5452
documentable_options = all_options - MANUAL_DOCS
@@ -60,7 +58,7 @@ def test_all_options_have_cli_doc_markers( # noqa: PLR6301
6058
+ "\n\nAdd @pytest.mark.cli_doc(...) to tests for these options."
6159
)
6260

63-
def test_meta_options_not_manual(self) -> None: # noqa: PLR6301
61+
def test_meta_options_not_manual(self) -> None:
6462
"""Verify that CLI_OPTION_META options are not in MANUAL_DOCS."""
6563
meta_options = set(CLI_OPTION_META.keys())
6664
overlap = meta_options & MANUAL_DOCS
@@ -70,9 +68,7 @@ def test_meta_options_not_manual(self) -> None: # noqa: PLR6301
7068
+ "\n".join(f" - {opt}" for opt in sorted(overlap))
7169
)
7270

73-
def test_collection_schema_version( # noqa: PLR6301
74-
self, collection_data: dict[str, Any]
75-
) -> None:
71+
def test_collection_schema_version(self, collection_data: dict[str, Any]) -> None:
7672
"""Verify that collection data has expected schema version."""
7773
version = collection_data.get("schema_version")
7874
assert version is not None, "Collection data missing 'schema_version'"
@@ -83,7 +79,7 @@ class TestCoverageStats: # pragma: no cover
8379
"""Informational tests for coverage statistics."""
8480

8581
@pytest.mark.skip(reason="Informational: run with -v --no-skip to see stats")
86-
def test_show_coverage_stats(self, collected_options: set[str]) -> None: # noqa: PLR6301
82+
def test_show_coverage_stats(self, collected_options: set[str]) -> None:
8783
"""Display documentation coverage statistics."""
8884
all_options = get_all_canonical_options()
8985
documentable = all_options - MANUAL_DOCS
@@ -94,9 +90,7 @@ def test_show_coverage_stats(self, collected_options: set[str]) -> None: # noqa
9490
print(f" {opt}") # noqa: T201
9591

9692
@pytest.mark.skip(reason="Informational: run with -v --no-skip to see stats")
97-
def test_show_documented_options( # noqa: PLR6301
98-
self, collected_options: set[str]
99-
) -> None:
93+
def test_show_documented_options(self, collected_options: set[str]) -> None:
10094
"""Display currently documented options."""
10195
print(f"\nDocumented options ({len(collected_options)}):") # noqa: T201
10296
for opt in sorted(collected_options):

tests/cli_doc/test_cli_options_sync.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_get_canonical_option() -> None:
3131
class TestCLIOptionMetaSync: # pragma: no cover
3232
"""Synchronization tests for CLI_OPTION_META."""
3333

34-
def test_all_registered_options_exist_in_argparse(self) -> None: # noqa: PLR6301
34+
def test_all_registered_options_exist_in_argparse(self) -> None:
3535
"""Verify that all options in CLI_OPTION_META exist in argparse."""
3636
argparse_options = get_all_canonical_options()
3737
registered = set(CLI_OPTION_META.keys())
@@ -44,7 +44,7 @@ def test_all_registered_options_exist_in_argparse(self) -> None: # noqa: PLR630
4444
+ "\n\nRemove them from CLI_OPTION_META or add them to arguments.py."
4545
)
4646

47-
def test_manual_doc_options_exist_in_argparse(self) -> None: # noqa: PLR6301
47+
def test_manual_doc_options_exist_in_argparse(self) -> None:
4848
"""Verify that all options in MANUAL_DOCS exist in argparse."""
4949
argparse_options = get_all_canonical_options()
5050

@@ -56,7 +56,7 @@ def test_manual_doc_options_exist_in_argparse(self) -> None: # noqa: PLR6301
5656
+ "\n\nRemove them from MANUAL_DOCS or add them to arguments.py."
5757
)
5858

59-
def test_no_overlap_between_meta_and_manual(self) -> None: # noqa: PLR6301
59+
def test_no_overlap_between_meta_and_manual(self) -> None:
6060
"""Verify that CLI_OPTION_META and MANUAL_DOCS don't overlap."""
6161
overlap = set(CLI_OPTION_META.keys()) & MANUAL_DOCS
6262
if overlap:
@@ -66,7 +66,7 @@ def test_no_overlap_between_meta_and_manual(self) -> None: # noqa: PLR6301
6666
+ "\n\nAn option should be in one or the other, not both."
6767
)
6868

69-
def test_meta_names_match_keys(self) -> None: # noqa: PLR6301
69+
def test_meta_names_match_keys(self) -> None:
7070
"""Verify that CLIOptionMeta.name matches the dict key."""
7171
mismatches = []
7272
for key, meta in CLI_OPTION_META.items():
@@ -76,7 +76,7 @@ def test_meta_names_match_keys(self) -> None: # noqa: PLR6301
7676
if mismatches:
7777
pytest.fail("CLIOptionMeta.name mismatches:\n" + "\n".join(mismatches))
7878

79-
def test_all_argparse_options_are_documented_or_excluded(self) -> None: # noqa: PLR6301
79+
def test_all_argparse_options_are_documented_or_excluded(self) -> None:
8080
"""Verify that all argparse options are either documented or explicitly excluded.
8181
8282
This test fails when a new CLI option is added to arguments.py
@@ -96,7 +96,7 @@ def test_all_argparse_options_are_documented_or_excluded(self) -> None: # noqa:
9696
"or add to MANUAL_DOCS if they should have manual documentation."
9797
)
9898

99-
def test_canonical_option_determination_is_stable(self) -> None: # noqa: PLR6301
99+
def test_canonical_option_determination_is_stable(self) -> None:
100100
"""Verify that canonical option determination is deterministic.
101101
102102
The canonical option should be the longest option string for each action.

0 commit comments

Comments
 (0)