Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,12 @@ lint.per-file-ignores."tests/**/*.py" = [
"INP001", # no implicit namespace
"PLC0415", # local imports in tests are fine
"PLC2701", # private import is fine
"PLR0904", # too many public methods in test classes is fine
"PLR0913", # as many arguments as want
"PLR0915", # can have longer test methods
"PLR0917", # as many arguments as want
"PLR2004", # Magic value used in comparison, consider replacing with a constant variable
"PLR6301", # test methods don't need to use self
"S", # no safety concerns
"SLF001", # can test private methods
]
Expand Down
36 changes: 36 additions & 0 deletions src/datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def _build_union_type_hint(self) -> str | None:
return f"Union[{', '.join(parts)}]"
return None # pragma: no cover

def _build_base_union_type_hint(self) -> str | None: # pragma: no cover
"""Build Union[] base type hint from data_type.data_types if forward reference requires it."""
if not (self._use_union_operator != self.data_type.use_union_operator and self.data_type.is_union):
return None
parts = [dt.base_type_hint for dt in self.data_type.data_types if dt.base_type_hint]
if len(parts) > 1:
return f"Union[{', '.join(parts)}]"
return None

@property
def type_hint(self) -> str: # noqa: PLR0911
"""Get the type hint string for this field, including nullability."""
Expand All @@ -241,6 +250,33 @@ def type_hint(self) -> str: # noqa: PLR0911
return get_optional_type(type_hint, self._use_union_operator)
return type_hint

@property
def base_type_hint(self) -> str:
    """Get the base type hint without constrained type kwargs.

    This returns the type without kwargs (e.g., 'str' instead of 'constr(pattern=...)').
    Used in RootModel generics when regex_engine config is needed for lookaround patterns.
    """
    hint = self._build_base_union_type_hint() or self.data_type.base_type_hint

    if not hint:  # pragma: no cover
        return NONE

    # Mirror the optionality rules of ``type_hint``: the hint is wrapped in
    # Optional when the field is nullable or omittable...
    wrap_optional = (
        self.nullable is True
        or (self.required and self.type_has_null)
        or (self.nullable is None and not self.required and self.fall_back_to_nullable)
    )
    # ...unless a default factory, an already-optional data type, or an explicit
    # ``nullable: false`` makes the extra Optional wrapper redundant.
    suppress_optional = (
        self.has_default_factory
        or (self.data_type.is_optional and self.data_type.type != ANY)
        or self.nullable is False
    )

    if wrap_optional and not suppress_optional:  # pragma: no cover
        return get_optional_type(hint, self._use_union_operator)
    return hint

@property
def imports(self) -> tuple[Import, ...]:
"""Get all imports required for this field's type hint."""
Expand Down
24 changes: 15 additions & 9 deletions src/datamodel_code_generator/model/pydantic_v2/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,8 @@ def __init__( # noqa: PLR0913
config_parameters["arbitrary_types_allowed"] = True
break

for field in self.fields:
# Check if a regex pattern uses lookarounds.
# Depending on the generation configuration, the pattern may end up in two different places.
pattern = (isinstance(field.constraints, Constraints) and field.constraints.pattern) or (
field.data_type.kwargs or {}
).get("pattern")
if pattern and re.search(r"\(\?<?[=!]", pattern):
config_parameters["regex_engine"] = '"python-re"'
break
if self._has_lookaround_pattern():
config_parameters["regex_engine"] = '"python-re"'

if isinstance(self.extra_template_data.get("config"), dict):
for key, value in self.extra_template_data["config"].items():
Expand Down Expand Up @@ -265,3 +258,16 @@ def _get_config_extra(self) -> Literal["'allow'", "'forbid'", "'ignore'"] | None
elif additional_properties is False:
config_extra = "'forbid'"
return config_extra

def _has_lookaround_pattern(self) -> bool:
    """Check if any field has a regex pattern with lookaround assertions.

    A pattern may live either in the field's ``constraints`` or in the kwargs
    of any nested data type; both locations are inspected.
    """
    lookaround = re.compile(r"\(\?<?[=!]")

    def has_lookaround(candidate) -> bool:
        # Falsy candidates (None, '', False from the isinstance guard) never match.
        return bool(candidate and lookaround.search(candidate))

    def field_patterns(field):
        if isinstance(field.constraints, Constraints):
            yield field.constraints.pattern
        for data_type in field.data_type.all_data_types:
            yield (data_type.kwargs or {}).get("pattern")

    return any(
        has_lookaround(pattern) for field in self.fields for pattern in field_patterns(field)
    )
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
{%- macro get_type_hint(_fields) -%}
{%- macro get_type_hint(_fields, use_base_type) -%}
{%- if _fields -%}
{#There will only ever be a single field for RootModel#}
{%- if use_base_type -%}
{{- _fields[0].base_type_hint}}
{%- else -%}
{{- _fields[0].type_hint}}
{%- endif -%}
{%- endif -%}
{%- endmacro -%}


{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}

class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields)}}]{%- endif -%}):{% if comment is defined %} # {{ comment }}{% endif %}
{#- Use base_type_hint in generic when regex_engine is set to avoid evaluating pattern before config is processed -#}
{%- set use_base_type = config and config.regex_engine -%}
class {{ class_name }}({{ base_class }}{%- if fields -%}[{{get_type_hint(fields, use_base_type)}}]{%- endif -%}):{% if comment is defined %} # {{ comment }}{% endif %}
{%- if description %}
"""
{{ description | indent(4) }}
Expand Down
122 changes: 121 additions & 1 deletion src/datamodel_code_generator/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,11 @@ def full_name(self) -> str:

@property
def all_data_types(self) -> Iterator[DataType]:
    """Recursively yield all nested DataTypes including self and dict_key.

    Children are traversed depth-first in declaration order, then the dict
    key's subtree (if any), and finally this node itself.
    """
    children = list(self.data_types)
    if self.dict_key:
        children.append(self.dict_key)
    for child in children:
        yield from child.all_data_types
    yield self

def find_source(self, source_type: type[SourceT]) -> SourceT | None:
Expand Down Expand Up @@ -618,6 +620,124 @@ def is_union(self) -> bool:
"""Return whether this DataType represents a union of multiple types."""
return len(self.data_types) > 1

# Mapping from constrained type functions to their base Python types.
# Only constr is included because it's the only type with a 'pattern' parameter
# that can trigger lookaround regex detection. Other constrained types (conint,
# confloat, condecimal, conbytes) don't have pattern constraints, so they will
# never need base_type_hint conversion in the regex_engine context.
_CONSTRAINED_TYPE_TO_BASE: ClassVar[dict[str, str]] = {
"constr": "str",
}

@property
def base_type_hint(self) -> str:  # noqa: PLR0912, PLR0915
    """Return the base type hint without constrained type kwargs.

    For types like constr(pattern=..., min_length=...), this returns just 'str'.
    This works recursively for nested types like list[constr(pattern=...)] -> list[str].

    This is useful when the pattern contains lookaround assertions that require
    regex_engine="python-re", which must be set in model_config. In such cases,
    the RootModel generic cannot use the constrained type because it would be
    evaluated at class definition time before model_config is processed.

    NOTE: like ``type_hint``, this property mutates ``self.is_optional`` while
    computing the result (when NONE members or nullable sources are found).
    """
    # Constrained-function types (e.g. constr(...)) are mapped to their plain base type.
    if self.is_func and self.kwargs:
        type_: str | None = self.alias or self.type
        if type_:  # pragma: no branch
            base_type = self._CONSTRAINED_TYPE_TO_BASE.get(type_)
            if base_type is None:
                # Not a constrained type we convert (e.g., conint, confloat)
                # Return the full type_hint with kwargs to avoid returning bare function name
                return self.type_hint
            if self.is_optional and base_type != ANY:  # pragma: no cover
                return get_optional_type(base_type, self.use_union_operator)
            return base_type

    type_: str | None = self.alias or self.type
    if not type_:
        # No direct name: derive the hint from the nested data types.
        if self.is_tuple:  # pragma: no cover
            tuple_type = STANDARD_TUPLE if self.use_standard_collections else TUPLE
            inner_types = [item.base_type_hint or ANY for item in self.data_types]
            type_ = f"{tuple_type}[{', '.join(inner_types)}]" if inner_types else f"{tuple_type}[()]"
        elif self.is_union:
            # Collect the distinct member hints, folding NONE members into is_optional.
            data_types: list[str] = []
            for data_type in self.data_types:
                data_type_type = data_type.base_type_hint
                if not data_type_type or data_type_type in data_types:  # pragma: no cover
                    continue

                if data_type_type == NONE:
                    self.is_optional = True
                    continue

                non_optional_data_type_type = _remove_none_from_union(
                    data_type_type, use_union_operator=self.use_union_operator
                )

                # A member that contained None contributes optionality to the whole union.
                if non_optional_data_type_type != data_type_type:  # pragma: no cover
                    self.is_optional = True

                data_types.append(non_optional_data_type_type)
            if not data_types:  # pragma: no cover
                type_ = ANY
                self.import_ = self.import_ or IMPORT_ANY
            elif len(data_types) == 1:
                # Union collapsed to a single member (e.g. ``Union[X, None]`` -> ``X``).
                type_ = data_types[0]
            elif self.use_union_operator:
                type_ = UNION_OPERATOR_DELIMITER.join(data_types)
            else:  # pragma: no cover
                type_ = f"{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]"
        elif len(self.data_types) == 1:
            type_ = self.data_types[0].base_type_hint
        elif self.enum_member_literals:  # pragma: no cover
            parts = [f"{enum_class}.{member}" for enum_class, member in self.enum_member_literals]
            type_ = f"{LITERAL}[{', '.join(parts)}]"
        elif self.literals:  # pragma: no cover
            type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
        elif self.reference:  # pragma: no cover
            type_ = self.reference.short_name
            type_ = self._get_wrapped_reference_type_hint(type_)
        else:  # pragma: no cover
            type_ = ""
    # A nullable reference source also makes the overall hint optional.
    if self.reference:  # pragma: no cover
        source = self.reference.source
        if isinstance(source, Nullable) and source.nullable:
            self.is_optional = True
    # Wrap the inner hint in the appropriate container type, if any.
    if self.is_list:
        if self.use_generic_container:
            list_ = SEQUENCE
        elif self.use_standard_collections:
            list_ = STANDARD_LIST
        else:  # pragma: no cover
            list_ = LIST
        type_ = f"{list_}[{type_}]" if type_ else list_
    elif self.is_set:  # pragma: no cover
        if self.use_generic_container:
            set_ = STANDARD_FROZEN_SET if self.use_standard_collections else FROZEN_SET
        elif self.use_standard_collections:
            set_ = STANDARD_SET
        else:
            set_ = SET
        type_ = f"{set_}[{type_}]" if type_ else set_
    elif self.is_dict:
        if self.use_generic_container:
            dict_ = MAPPING
        elif self.use_standard_collections:
            dict_ = STANDARD_DICT
        else:  # pragma: no cover
            dict_ = DICT
        if self.dict_key or type_:
            # Dict keys recurse through base_type_hint too; default key type is str.
            key = self.dict_key.base_type_hint if self.dict_key else STR
            type_ = f"{dict_}[{key}, {type_ or ANY}]"
        else:  # pragma: no cover
            type_ = dict_

    if self.is_optional and type_ != ANY:
        return get_optional_type(type_, self.use_union_operator)
    if self.is_func:  # pragma: no cover
        return f"{type_}()"
    return type_


DataTypeT = TypeVar("DataTypeT", bound=DataType)

Expand Down
16 changes: 5 additions & 11 deletions tests/cli_doc/test_cli_doc_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def collected_options(collection_data: dict[str, Any]) -> set[str]: # pragma: n
class TestCLIDocCoverage: # pragma: no cover
"""Documentation coverage tests."""

def test_all_options_have_cli_doc_markers( # noqa: PLR6301
self, collected_options: set[str]
) -> None:
def test_all_options_have_cli_doc_markers(self, collected_options: set[str]) -> None:
"""Verify that all CLI options (except MANUAL_DOCS) have cli_doc markers."""
all_options = get_all_canonical_options()
documentable_options = all_options - MANUAL_DOCS
Expand All @@ -60,7 +58,7 @@ def test_all_options_have_cli_doc_markers( # noqa: PLR6301
+ "\n\nAdd @pytest.mark.cli_doc(...) to tests for these options."
)

def test_meta_options_not_manual(self) -> None: # noqa: PLR6301
def test_meta_options_not_manual(self) -> None:
"""Verify that CLI_OPTION_META options are not in MANUAL_DOCS."""
meta_options = set(CLI_OPTION_META.keys())
overlap = meta_options & MANUAL_DOCS
Expand All @@ -70,9 +68,7 @@ def test_meta_options_not_manual(self) -> None: # noqa: PLR6301
+ "\n".join(f" - {opt}" for opt in sorted(overlap))
)

def test_collection_schema_version( # noqa: PLR6301
self, collection_data: dict[str, Any]
) -> None:
def test_collection_schema_version(self, collection_data: dict[str, Any]) -> None:
"""Verify that collection data has expected schema version."""
version = collection_data.get("schema_version")
assert version is not None, "Collection data missing 'schema_version'"
Expand All @@ -83,7 +79,7 @@ class TestCoverageStats: # pragma: no cover
"""Informational tests for coverage statistics."""

@pytest.mark.skip(reason="Informational: run with -v --no-skip to see stats")
def test_show_coverage_stats(self, collected_options: set[str]) -> None: # noqa: PLR6301
def test_show_coverage_stats(self, collected_options: set[str]) -> None:
"""Display documentation coverage statistics."""
all_options = get_all_canonical_options()
documentable = all_options - MANUAL_DOCS
Expand All @@ -94,9 +90,7 @@ def test_show_coverage_stats(self, collected_options: set[str]) -> None: # noqa
print(f" {opt}") # noqa: T201

@pytest.mark.skip(reason="Informational: run with -v --no-skip to see stats")
def test_show_documented_options( # noqa: PLR6301
self, collected_options: set[str]
) -> None:
def test_show_documented_options(self, collected_options: set[str]) -> None:
"""Display currently documented options."""
print(f"\nDocumented options ({len(collected_options)}):") # noqa: T201
for opt in sorted(collected_options):
Expand Down
12 changes: 6 additions & 6 deletions tests/cli_doc/test_cli_options_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_get_canonical_option() -> None:
class TestCLIOptionMetaSync: # pragma: no cover
"""Synchronization tests for CLI_OPTION_META."""

def test_all_registered_options_exist_in_argparse(self) -> None: # noqa: PLR6301
def test_all_registered_options_exist_in_argparse(self) -> None:
"""Verify that all options in CLI_OPTION_META exist in argparse."""
argparse_options = get_all_canonical_options()
registered = set(CLI_OPTION_META.keys())
Expand All @@ -44,7 +44,7 @@ def test_all_registered_options_exist_in_argparse(self) -> None: # noqa: PLR630
+ "\n\nRemove them from CLI_OPTION_META or add them to arguments.py."
)

def test_manual_doc_options_exist_in_argparse(self) -> None: # noqa: PLR6301
def test_manual_doc_options_exist_in_argparse(self) -> None:
"""Verify that all options in MANUAL_DOCS exist in argparse."""
argparse_options = get_all_canonical_options()

Expand All @@ -56,7 +56,7 @@ def test_manual_doc_options_exist_in_argparse(self) -> None: # noqa: PLR6301
+ "\n\nRemove them from MANUAL_DOCS or add them to arguments.py."
)

def test_no_overlap_between_meta_and_manual(self) -> None: # noqa: PLR6301
def test_no_overlap_between_meta_and_manual(self) -> None:
"""Verify that CLI_OPTION_META and MANUAL_DOCS don't overlap."""
overlap = set(CLI_OPTION_META.keys()) & MANUAL_DOCS
if overlap:
Expand All @@ -66,7 +66,7 @@ def test_no_overlap_between_meta_and_manual(self) -> None: # noqa: PLR6301
+ "\n\nAn option should be in one or the other, not both."
)

def test_meta_names_match_keys(self) -> None: # noqa: PLR6301
def test_meta_names_match_keys(self) -> None:
"""Verify that CLIOptionMeta.name matches the dict key."""
mismatches = []
for key, meta in CLI_OPTION_META.items():
Expand All @@ -76,7 +76,7 @@ def test_meta_names_match_keys(self) -> None: # noqa: PLR6301
if mismatches:
pytest.fail("CLIOptionMeta.name mismatches:\n" + "\n".join(mismatches))

def test_all_argparse_options_are_documented_or_excluded(self) -> None: # noqa: PLR6301
def test_all_argparse_options_are_documented_or_excluded(self) -> None:
"""Verify that all argparse options are either documented or explicitly excluded.

This test fails when a new CLI option is added to arguments.py
Expand All @@ -96,7 +96,7 @@ def test_all_argparse_options_are_documented_or_excluded(self) -> None: # noqa:
"or add to MANUAL_DOCS if they should have manual documentation."
)

def test_canonical_option_determination_is_stable(self) -> None: # noqa: PLR6301
def test_canonical_option_determination_is_stable(self) -> None:
"""Verify that canonical option determination is deterministic.

The canonical option should be the longest option string for each action.
Expand Down
Loading
Loading