Skip to content

Commit c70ff8c

Browse files
committed
Fix discriminator with allOf without Literal type for Pydantic v2
1 parent 3e7b16b commit c70ff8c

File tree

7 files changed

+243
-3
lines changed

7 files changed

+243
-3
lines changed

src/datamodel_code_generator/parser/openapi.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,54 @@ def get_data_type(self, obj: JsonSchemaObject) -> DataType:
395395

396396
return super().get_data_type(obj)
397397

398+
def _normalize_discriminator_mapping_ref(self, mapping_value: str) -> str: # noqa: PLR6301
399+
"""Normalize a discriminator mapping value to a full $ref path.
400+
401+
Per OpenAPI spec, mapping values can be either:
402+
- Full refs: "#/components/schemas/Pet" or "./other.yaml#/components/schemas/Pet"
403+
- Short names: "Pet" or "Pet.V1" (relative to #/components/schemas/)
404+
- Relative paths: "schemas/Pet" or "./other.yaml"
405+
406+
Values containing "/" or "#" are treated as paths/refs and passed through.
407+
All other values (including those with dots like "Pet.V1") are treated as
408+
short schema names and normalized to full refs.
409+
410+
Note: Bare file references without path separators (e.g., "other.yaml") will be
411+
treated as schema names. Use "./other.yaml" format for file references.
412+
413+
Note: This could be a staticmethod, but @snooper_to_methods() decorator
414+
converts staticmethods to regular functions when pysnooper is installed.
415+
"""
416+
if "/" in mapping_value or "#" in mapping_value:
417+
return mapping_value
418+
return f"#/components/schemas/{mapping_value}"
419+
420+
def _normalize_discriminator(self, discriminator: dict[str, Any]) -> dict[str, Any]:
421+
"""Return a copy of the discriminator dict with normalized mapping refs."""
422+
result = discriminator.copy()
423+
mapping = discriminator.get("mapping")
424+
if mapping:
425+
result["mapping"] = {
426+
k: self._normalize_discriminator_mapping_ref(v) for k, v in mapping.items() if isinstance(v, str)
427+
}
428+
return result
429+
398430
def _get_discriminator_union_type(self, ref: str) -> DataType | None:
399-
"""Create a union type for discriminator subtypes if available."""
431+
"""Create a union type for discriminator subtypes if available.
432+
433+
First tries to use allOf subtypes. If none found, falls back to using
434+
the discriminator mapping to create the union type. This handles cases
435+
where schemas don't use allOf inheritance but have explicit discriminator mappings.
436+
"""
400437
subtypes = self._discriminator_subtypes.get(ref, [])
438+
if not subtypes:
439+
discriminator = self._discriminator_schemas.get(ref)
440+
if discriminator:
441+
mapping = discriminator.get("mapping", {})
442+
if mapping:
443+
subtypes = [
444+
self._normalize_discriminator_mapping_ref(v) for v in mapping.values() if isinstance(v, str)
445+
]
401446
if not subtypes:
402447
return None
403448
refs = map(self.model_resolver.add_ref, subtypes)
@@ -430,10 +475,11 @@ def parse_object_fields(
430475
and (discriminator := self._discriminator_schemas.get(field.ref))
431476
):
432477
new_field_type = self._get_discriminator_union_type(field.ref) or field_obj.data_type
478+
normalized_discriminator = self._normalize_discriminator(discriminator)
433479
field_obj = self.data_model_field_type(**{ # noqa: PLW2901
434480
**field_obj.__dict__,
435481
"data_type": new_field_type,
436-
"extras": {**field_obj.extras, "discriminator": discriminator},
482+
"extras": {**field_obj.extras, "discriminator": normalized_discriminator},
437483
})
438484
result_fields.append(field_obj)
439485

tests/data/expected/main/openapi/discriminator/allof_no_subtypes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import annotations
66

7+
from typing import Literal
8+
79
from pydantic import BaseModel, Field
810

911

@@ -13,11 +15,13 @@ class BaseItem(BaseModel):
1315

1416
class FooItem(BaseModel):
1517
fooValue: str | None = None
18+
itemType: Literal['foo']
1619

1720

1821
class BarItem(BaseModel):
1922
barValue: int | None = None
23+
itemType: Literal['bar']
2024

2125

2226
class ItemContainer(BaseModel):
23-
item: BaseItem = Field(..., discriminator='itemType')
27+
item: FooItem | BarItem = Field(..., discriminator='itemType')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_no_mapping.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class BaseItem(BaseModel):
13+
itemType: str
14+
15+
16+
class FooItem(BaseItem):
17+
fooValue: str | None = None
18+
itemType: Literal['FooItem']
19+
20+
21+
class BarItem(BaseItem):
22+
barValue: int | None = None
23+
itemType: Literal['BarItem']
24+
25+
26+
class ItemContainer(BaseModel):
27+
item: FooItem | BarItem = Field(..., discriminator='itemType')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# generated by datamodel-codegen:
2+
# filename: discriminator_short_mapping_names.yaml
3+
# timestamp: 2019-07-26T00:00:00+00:00
4+
5+
from __future__ import annotations
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class BaseItem(BaseModel):
13+
itemType: str
14+
15+
16+
class FooItem(BaseModel):
17+
fooValue: str | None = None
18+
itemType: Literal['foo']
19+
20+
21+
class BarItem(BaseModel):
22+
barValue: int | None = None
23+
itemType: Literal['bar']
24+
25+
26+
class ItemContainer(BaseModel):
27+
item: FooItem | BarItem = Field(..., discriminator='itemType')
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
openapi: 3.1.0
2+
info:
3+
title: Test
4+
description: "Test API"
5+
version: 0.0.0
6+
paths:
7+
/item:
8+
get:
9+
responses:
10+
'200':
11+
description: "Item"
12+
content:
13+
application/json:
14+
schema:
15+
$ref: "#/components/schemas/ItemContainer"
16+
components:
17+
schemas:
18+
ItemContainer:
19+
type: object
20+
required:
21+
- item
22+
properties:
23+
item:
24+
$ref: "#/components/schemas/BaseItem"
25+
BaseItem:
26+
type: object
27+
required:
28+
- itemType
29+
properties:
30+
itemType:
31+
type: string
32+
discriminator:
33+
propertyName: itemType
34+
FooItem:
35+
allOf:
36+
- $ref: "#/components/schemas/BaseItem"
37+
- type: object
38+
properties:
39+
fooValue:
40+
type: string
41+
BarItem:
42+
allOf:
43+
- $ref: "#/components/schemas/BaseItem"
44+
- type: object
45+
properties:
46+
barValue:
47+
type: integer
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Tests discriminator with short mapping names (not full $ref paths)
2+
# Per OpenAPI spec, mapping values can be short names like "FooItem" instead of "#/components/schemas/FooItem"
3+
4+
openapi: 3.1.0
5+
info:
6+
title: Test
7+
description: "Test API"
8+
version: 0.0.0
9+
paths:
10+
/item:
11+
get:
12+
responses:
13+
'200':
14+
description: "Item"
15+
content:
16+
application/json:
17+
schema:
18+
$ref: "#/components/schemas/ItemContainer"
19+
components:
20+
schemas:
21+
ItemContainer:
22+
type: object
23+
required:
24+
- item
25+
properties:
26+
item:
27+
$ref: "#/components/schemas/BaseItem"
28+
BaseItem:
29+
type: object
30+
required:
31+
- itemType
32+
properties:
33+
itemType:
34+
type: string
35+
discriminator:
36+
propertyName: itemType
37+
mapping:
38+
# Short names without full $ref paths
39+
foo: FooItem
40+
bar: BarItem
41+
FooItem:
42+
type: object
43+
properties:
44+
fooValue:
45+
type: string
46+
BarItem:
47+
type: object
48+
properties:
49+
barValue:
50+
type: integer

tests/main/openapi/test_main_openapi.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,45 @@ def test_main_openapi_discriminator_allof_no_subtypes(output_file: Path) -> None
271271
)
272272

273273

274+
def test_main_openapi_discriminator_short_mapping_names(output_file: Path) -> None:
275+
"""Test OpenAPI generation with discriminator using short mapping names.
276+
277+
Per OpenAPI spec, mapping values can be short names like "FooItem" instead
278+
of full refs like "#/components/schemas/FooItem". This tests that short
279+
names are normalized correctly.
280+
"""
281+
run_main_and_assert(
282+
input_path=OPEN_API_DATA_PATH / "discriminator_short_mapping_names.yaml",
283+
output_path=output_file,
284+
input_file_type="openapi",
285+
assert_func=assert_file_content,
286+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "short_mapping_names.py",
287+
extra_args=[
288+
"--output-model-type",
289+
"pydantic_v2.BaseModel",
290+
],
291+
)
292+
293+
294+
def test_main_openapi_discriminator_no_mapping(output_file: Path) -> None:
295+
"""Test OpenAPI generation with discriminator without mapping.
296+
297+
This tests the case where a discriminator has only propertyName but no mapping.
298+
The subtypes are discovered via allOf inheritance.
299+
"""
300+
run_main_and_assert(
301+
input_path=OPEN_API_DATA_PATH / "discriminator_no_mapping.yaml",
302+
output_path=output_file,
303+
input_file_type="openapi",
304+
assert_func=assert_file_content,
305+
expected_file=EXPECTED_OPENAPI_PATH / "discriminator" / "no_mapping.py",
306+
extra_args=[
307+
"--output-model-type",
308+
"pydantic_v2.BaseModel",
309+
],
310+
)
311+
312+
274313
def test_main_openapi_allof_with_oneof_ref(output_file: Path) -> None:
275314
"""Test OpenAPI generation with allOf referencing a oneOf schema.
276315

0 commit comments

Comments
 (0)