Skip to content

Commit d41a19b

Browse files
authored
fix: wrap RootModel primitive defaults with default_factory (#2714)
1 parent 92a32c0 commit d41a19b

File tree

13 files changed

+87
-45
lines changed

13 files changed

+87
-45
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Internal types for model module."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any
7+
8+
9+
@dataclass(repr=False)
10+
class WrappedDefault:
11+
"""Represents a default value wrapped with its type constructor."""
12+
13+
value: Any
14+
type_name: str
15+
16+
def __repr__(self) -> str:
17+
"""Return type constructor representation, e.g., 'CountType(10)'."""
18+
return f"{self.type_name}({self.value!r})"

src/datamodel_code_generator/model/base.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from abc import ABC, abstractmethod
1111
from collections import defaultdict
1212
from copy import deepcopy
13-
from dataclasses import dataclass
1413
from functools import cached_property, lru_cache
1514
from pathlib import Path
1615
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar, Union
@@ -26,6 +25,7 @@
2625
IMPORT_UNION,
2726
Import,
2827
)
28+
from datamodel_code_generator.model._types import WrappedDefault
2929
from datamodel_code_generator.reference import Reference, _BaseModel
3030
from datamodel_code_generator.types import (
3131
ANY,
@@ -39,6 +39,8 @@
3939
)
4040
from datamodel_code_generator.util import PYDANTIC_V2, ConfigDict
4141

42+
__all__ = ["WrappedDefault"]
43+
4244
if TYPE_CHECKING:
4345
from collections.abc import Iterator
4446

@@ -113,18 +115,6 @@ def merge_constraints(a: ConstraintsBaseT | None, b: ConstraintsBaseT | None) ->
113115
})
114116

115117

116-
@dataclass(repr=False)
117-
class WrappedDefault:
118-
"""Represents a default value wrapped with its type constructor."""
119-
120-
value: Any
121-
type_name: str
122-
123-
def __repr__(self) -> str:
124-
"""Return type constructor representation, e.g., 'CountType(10)'."""
125-
return f"{self.type_name}({self.value!r})"
126-
127-
128118
class DataModelFieldBase(_BaseModel):
129119
"""Base class for model field representation and rendering."""
130120

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DataModel,
1818
DataModelFieldBase,
1919
)
20+
from datamodel_code_generator.model._types import WrappedDefault
2021
from datamodel_code_generator.model.base import UNDEFINED
2122
from datamodel_code_generator.model.pydantic.imports import (
2223
IMPORT_ANYURL,
@@ -122,6 +123,8 @@ def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any
122123
return int(value)
123124

124125
def _get_default_as_pydantic_model(self) -> str | None:
126+
if isinstance(self.default, WrappedDefault):
127+
return f"lambda :{self.default!r}"
125128
for data_type in self.data_type.data_types or (self.data_type,):
126129
# TODO: Check nested data_types
127130
if data_type.is_dict:
@@ -141,15 +144,18 @@ def _get_default_as_pydantic_model(self) -> str | None:
141144
f"{self._PARSE_METHOD}(v) for v in {self.default!r}]"
142145
)
143146
elif data_type.reference and isinstance(data_type.reference.source, BaseModelBase):
147+
source = data_type.reference.source
148+
is_root_model = hasattr(source, "BASE_CLASS") and source.BASE_CLASS == "pydantic.RootModel"
144149
if self.data_type.is_union:
145150
if not isinstance(self.default, (dict, list)):
151+
if not is_root_model:
152+
continue
153+
elif isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types):
146154
continue
147-
if isinstance(self.default, dict) and any(dt.is_dict for dt in self.data_type.data_types):
148-
continue
149-
return (
150-
f"lambda :{data_type.alias or data_type.reference.source.class_name}."
151-
f"{self._PARSE_METHOD}({self.default!r})"
152-
)
155+
class_name = data_type.alias or source.class_name
156+
if is_root_model:
157+
return f"lambda :{class_name}({self.default!r})"
158+
return f"lambda :{class_name}.{self._PARSE_METHOD}({self.default!r})"
153159
return None
154160

155161
def _process_data_in_str(self, data: dict[str, Any]) -> None:

tests/data/expected/main/jsonschema/root_model_default_value.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,7 @@ class NameType(RootModel[str]):
2525

2626
class Model(BaseModel):
2727
admin_state: AdminStateLeaf | None = AdminStateLeaf.enable
28-
count: Annotated[
29-
CountType | None,
30-
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
31-
]
28+
count: Annotated[CountType | None, Field(default_factory=lambda: CountType(10))]
3229
name: Annotated[
33-
NameType | None,
34-
Field(
35-
default_factory=lambda: NameType.model_validate(NameType('default_name'))
36-
),
30+
NameType | None, Field(default_factory=lambda: NameType('default_name'))
3731
]

tests/data/expected/main/jsonschema/root_model_default_value_branches.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ class CountType(RootModel[int]):
1515

1616
class Model(BaseModel):
1717
count_with_default: Annotated[
18-
CountType | None,
19-
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
18+
CountType | None, Field(default_factory=lambda: CountType(10))
2019
]
2120
count_no_default: CountType | None = None
2221
count_list_default: Annotated[

tests/data/expected/main/jsonschema/root_model_default_value_no_annotated.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,5 @@ class NameType(RootModel[constr(min_length=1, max_length=50)]):
2424

2525
class Model(BaseModel):
2626
admin_state: AdminStateLeaf | None = AdminStateLeaf.enable
27-
count: CountType | None = Field(
28-
default_factory=lambda: CountType.model_validate(10)
29-
)
30-
name: NameType | None = Field(
31-
default_factory=lambda: NameType.model_validate('default_name')
32-
)
27+
count: CountType | None = Field(default_factory=lambda: CountType(10))
28+
name: NameType | None = Field(default_factory=lambda: NameType('default_name'))

tests/data/expected/main/jsonschema/root_model_default_value_non_root.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ class PersonType(BaseModel):
1919

2020
class Model(BaseModel):
2121
root_model_field: Annotated[
22-
CountType | None,
23-
Field(default_factory=lambda: CountType.model_validate(CountType(10))),
22+
CountType | None, Field(default_factory=lambda: CountType(10))
2423
]
2524
non_root_model_field: Annotated[
2625
PersonType | None,

tests/data/expected/main/openapi/pydantic_v2_default_object/Nested.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,4 @@ class Bar(BaseModel):
2323
for v in [{'text': 'abc', 'number': 123}, {'text': 'efg', 'number': 456}]
2424
]
2525
)
26-
nested_foo: Foo | None = Field(
27-
default_factory=lambda: Foo.model_validate('default foo')
28-
)
26+
nested_foo: Foo | None = Field(default_factory=lambda: Foo('default foo'))

tests/data/expected/main/openapi/referenced_default.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,4 @@ class ModelSettingB(RootModel[confloat(ge=0.0, le=10.0)]):
1313

1414
class Model(BaseModel):
1515
settingA: confloat(ge=0.0, le=10.0) | None = 5
16-
settingB: ModelSettingB | None = Field(
17-
default_factory=lambda: ModelSettingB.model_validate(5)
18-
)
16+
settingB: ModelSettingB | None = Field(default_factory=lambda: ModelSettingB(5))

tests/data/expected/main/openapi/referenced_default_use_annotated.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,5 @@ class ModelSettingB(RootModel[float]):
1616
class Model(BaseModel):
1717
settingA: Annotated[float | None, Field(ge=0.0, le=10.0)] = 5
1818
settingB: Annotated[
19-
ModelSettingB | None,
20-
Field(default_factory=lambda: ModelSettingB.model_validate(ModelSettingB(5))),
19+
ModelSettingB | None, Field(default_factory=lambda: ModelSettingB(5))
2120
]

0 commit comments

Comments
 (0)