Skip to content

Commit a6a7b04

Browse files
rpmcgintykoxudaxi
andauthored
Fix bug in handling of graphql empty list defaults (#2948)
* adding tests to highlight the graphql bug * fixing test to actually compare files (previously was not) * fix previous test that did not specify assert_func * Fix empty list default for GraphQL list fields * Fix GraphQL empty list default handling * Skip empty list default tests on black 22 --------- Co-authored-by: Koudai Aono <koxudaxi@gmail.com>
1 parent 838b2a0 commit a6a7b04

File tree

8 files changed

+140
-12
lines changed

8 files changed

+140
-12
lines changed

src/datamodel_code_generator/model/pydantic/base_model.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,22 @@ def _get_strict_field_constraint_value(self, constraint: str, value: Any) -> Any
126126
return value
127127
return int(value)
128128

129-
def _get_default_as_pydantic_model(self) -> str | None:
129+
def _get_default_as_pydantic_model(self) -> str | None: # noqa: PLR0911, PLR0912
130130
if isinstance(self.default, WrappedDefault):
131131
return f"lambda :{self.default!r}"
132+
if self.data_type.is_list and len(self.data_type.data_types) == 1:
133+
data_type_child = self.data_type.data_types[0]
134+
if (
135+
data_type_child.reference
136+
and isinstance(data_type_child.reference.source, BaseModelBase)
137+
and isinstance(self.default, list)
138+
):
139+
if not self.default:
140+
return STANDARD_LIST
141+
return ( # pragma: no cover
142+
f"lambda :[{data_type_child.alias or data_type_child.reference.source.class_name}."
143+
f"{self._PARSE_METHOD}(v) for v in {self.default!r}]"
144+
)
132145
for data_type in self.data_type.data_types or (self.data_type,):
133146
# TODO: Check nested data_types
134147
if data_type.is_dict:
@@ -220,7 +233,7 @@ def __str__(self) -> str: # noqa: PLR0912
220233
elif isinstance(discriminator, dict): # pragma: no cover
221234
data["discriminator"] = discriminator["propertyName"]
222235

223-
if self.required:
236+
if self.required and not self.has_default:
224237
default_factory = None
225238
elif self.default is not UNDEFINED and self.default is not None and "default_factory" not in data:
226239
default_factory = self._get_default_as_pydantic_model()
@@ -249,7 +262,7 @@ def __str__(self) -> str: # noqa: PLR0912
249262

250263
if self.use_annotated:
251264
field_arguments = self._process_annotated_field_arguments(field_arguments)
252-
elif self.required:
265+
elif self.required and not default_factory:
253266
field_arguments = ["...", *field_arguments]
254267
elif not default_factory:
255268
default_repr = repr_set_sorted(self.default) if isinstance(self.default, set) else repr(self.default)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# generated by datamodel-codegen:
2+
# filename: empty_list_default.graphql
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
type Boolean = bool
12+
"""
13+
The `Boolean` scalar type represents `true` or `false`.
14+
"""
15+
16+
17+
type String = str
18+
"""
19+
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
20+
"""
21+
22+
23+
class Container(BaseModel):
24+
name: String
25+
typename__: Literal['Container'] | None = Field('Container', alias='__typename')
26+
27+
28+
class PodSpec(BaseModel):
29+
container_list: list[Container] = Field(default_factory=list)
30+
container_list_or_none: list[Container | None] = Field(default_factory=list)
31+
container_or_none_list_or_none: list[Container | None] | None = Field(
32+
default_factory=list
33+
)
34+
typename__: Literal['PodSpec'] | None = Field('PodSpec', alias='__typename')
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# generated by datamodel-codegen:
2+
# filename: empty_list_default.graphql
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
type Boolean = bool
12+
"""
13+
The `Boolean` scalar type represents `true` or `false`.
14+
"""
15+
16+
17+
type String = str
18+
"""
19+
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
20+
"""
21+
22+
23+
class Container(BaseModel):
24+
name: String
25+
typename__: Literal['Container'] | None = Field('Container', alias='__typename')
26+
27+
28+
class PodSpec(BaseModel):
29+
container_list: list[Container] = Field(default_factory=list)
30+
container_list_or_none: list[Container | None] = Field(default_factory=list)
31+
container_or_none_list_or_none: list[Container | None] | None = Field(
32+
default_factory=list
33+
)
34+
typename__: Literal['PodSpec'] | None = Field('PodSpec', alias='__typename')

tests/data/expected/main/openapi/empty_list_default.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44

55
from __future__ import annotations
66

7-
from typing import List, Optional
8-
97
from pydantic import BaseModel, Field
108

119

1210
class Container(BaseModel):
13-
name: Optional[str] = None
11+
name: str | None = None
1412

1513

1614
class PodSpec(BaseModel):
17-
containers: Optional[List[Container]] = Field(default_factory=list)
15+
containers: list[Container] | None = Field(default_factory=list)

tests/data/expected/main/openapi/pydantic_v2_empty_list_default.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44

55
from __future__ import annotations
66

7-
from typing import List, Optional
8-
97
from pydantic import BaseModel, Field
108

119

1210
class Container(BaseModel):
13-
name: Optional[str] = None
11+
name: str | None = None
1412

1513

1614
class PodSpec(BaseModel):
17-
containers: Optional[List[Container]] = Field(default_factory=list)
15+
containers: list[Container] | None = Field(default_factory=list)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
input Container {
3+
name: String!
4+
}
5+
6+
input PodSpec {
7+
container_list: [Container!]! = []
8+
container_list_or_none: [Container]! = []
9+
container_or_none_list_or_none: [Container] = []
10+
}

tests/main/graphql/test_main_graphql.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
import black
88
import pytest
99

10-
from tests.main.conftest import DEFAULT_VALUES_DATA_PATH, GRAPHQL_DATA_PATH, LEGACY_BLACK_SKIP, run_main_and_assert
10+
from tests.main.conftest import (
11+
DEFAULT_VALUES_DATA_PATH,
12+
EXPECTED_GRAPHQL_PATH,
13+
GRAPHQL_DATA_PATH,
14+
LEGACY_BLACK_SKIP,
15+
run_main_and_assert,
16+
)
1117
from tests.main.graphql.conftest import assert_file_content
1218

1319
if TYPE_CHECKING:
@@ -104,6 +110,40 @@ def test_main_use_default_kwarg(output_file: Path) -> None:
104110
)
105111

106112

113+
@pytest.mark.parametrize(
114+
("output_model", "expected_output"),
115+
[
116+
(
117+
"pydantic.BaseModel",
118+
"empty_list_default.py",
119+
),
120+
(
121+
"pydantic_v2.BaseModel",
122+
"pydantic_v2_empty_list_default.py",
123+
),
124+
],
125+
)
126+
@pytest.mark.skipif(
127+
black.__version__.split(".")[0] in {"19", "22"},
128+
reason="Installed black doesn't support Python 3.12 target version",
129+
)
130+
def test_main_graphql_empty_list_default(output_model: str, expected_output: str, output_file: Path) -> None:
131+
"""Test GraphQL generation with empty list default values."""
132+
run_main_and_assert(
133+
input_path=GRAPHQL_DATA_PATH / "empty_list_default.graphql",
134+
output_path=output_file,
135+
assert_func=assert_file_content,
136+
expected_file=EXPECTED_GRAPHQL_PATH / expected_output,
137+
input_file_type="graphql",
138+
extra_args=[
139+
"--output-model-type",
140+
output_model,
141+
"--target-python-version",
142+
"3.12",
143+
],
144+
)
145+
146+
107147
@pytest.mark.skipif(
108148
black.__version__.split(".")[0] == "19",
109149
reason="Installed black doesn't support the old style",

tests/main/openapi/test_main_openapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2808,6 +2808,7 @@ def test_main_openapi_empty_list_default(output_model: str, expected_output: str
28082808
input_path=OPEN_API_DATA_PATH / "empty_list_default.yaml",
28092809
output_path=output_file,
28102810
expected_file=EXPECTED_OPENAPI_PATH / expected_output,
2811+
assert_func=assert_file_content,
28112812
input_file_type="openapi",
28122813
extra_args=[
28132814
"--output-model-type",

0 commit comments

Comments
 (0)