Skip to content

Commit a07509f

Browse files
committed
Map discriminator mapping values with real enum values
1 parent e23d1c6 commit a07509f

File tree

1 file changed

+46
-35
lines changed
  • src/datamodel_code_generator/parser

1 file changed

+46
-35
lines changed

src/datamodel_code_generator/parser/base.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,17 +1521,33 @@ def _create_discriminator_data_type(
15211521
) -> DataType:
15221522
"""Create a data type for discriminator field, using enum literals if available."""
15231523
if enum_source:
1524-
enum_class_name = enum_source.reference.short_name
1525-
enum_member_literals: list[tuple[str, str]] = []
1526-
for value in discriminator_values:
1527-
member = enum_source.find_member(value)
1528-
if member and member.field.name:
1529-
enum_member_literals.append((enum_class_name, member.field.name))
1530-
else: # pragma: no cover
1531-
enum_member_literals.append((enum_class_name, str(value)))
1532-
data_type = self.data_type(enum_member_literals=enum_member_literals)
1533-
if enum_source.module_path != discriminator_model.module_path: # pragma: no cover
1534-
imports.append(Import.from_full_path(enum_source.name))
1524+
if self.use_enum_values_in_discriminator:
1525+
enum_class_name = enum_source.reference.short_name
1526+
enum_member_literals: list[tuple[str, DiscriminatorValue]] = []
1527+
for value in discriminator_values:
1528+
member = enum_source.find_member(value)
1529+
if member and member.field.name:
1530+
enum_member_literals.append((enum_class_name, member.field.name))
1531+
else: # pragma: no cover
1532+
enum_member_literals.append((enum_class_name, str(value)))
1533+
data_type = self.data_type(enum_member_literals=enum_member_literals)
1534+
if enum_source.module_path != discriminator_model.module_path: # pragma: no cover
1535+
imports.append(Import.from_full_path(enum_source.name))
1536+
else:
1537+
# According to OpenAPI specification, mapping discriminators are always string values.
1538+
# However, if the mapped object is an enum, we want to use the real enum value instead of
1539+
# the string value.
1540+
# See: https://swagger.io/specification/#options-for-mapping-values-to-schemas
1541+
# Fix: https://github.com/koxudaxi/datamodel-code-generator/issues/3073
1542+
for i, value in enumerate(discriminator_values):
1543+
if member := enum_source.find_member(value):
1544+
match member.field.default:
1545+
case str():
1546+
discriminator_values[i] = member.field.default.strip("'\"")
1547+
case _ if isinstance(member.field.default, DiscriminatorValue):
1548+
discriminator_values[i] = member.field.default
1549+
1550+
data_type = self.data_type(literals=discriminator_values)
15351551
else:
15361552
data_type = self.data_type(literals=discriminator_values)
15371553
return data_type
@@ -1637,28 +1653,27 @@ def get_discriminator_field_value(
16371653
raise RuntimeError(msg)
16381654

16391655
enum_from_base: Enum | None = None
1640-
if self.use_enum_values_in_discriminator:
1641-
for base_class in discriminator_model.base_classes:
1642-
if not base_class.reference or not base_class.reference.source: # pragma: no cover
1643-
continue
1644-
base_model = base_class.reference.source
1645-
if not isinstance( # pragma: no cover
1646-
base_model,
1647-
(
1648-
pydantic_model_v2.BaseModel,
1649-
dataclass_model.DataClass,
1650-
msgspec_model.Struct,
1651-
),
1652-
):
1656+
for base_class in discriminator_model.base_classes:
1657+
if not base_class.reference or not base_class.reference.source: # pragma: no cover
1658+
continue
1659+
base_model = base_class.reference.source
1660+
if not isinstance( # pragma: no cover
1661+
base_model,
1662+
(
1663+
pydantic_model_v2.BaseModel,
1664+
dataclass_model.DataClass,
1665+
msgspec_model.Struct,
1666+
),
1667+
):
1668+
continue
1669+
for base_field in base_model.fields: # pragma: no branch
1670+
if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover
16531671
continue
1654-
for base_field in base_model.fields: # pragma: no branch
1655-
if field_name not in {base_field.original_name, base_field.name}: # pragma: no cover
1656-
continue
1657-
enum_from_base = base_field.data_type.find_source(Enum)
1658-
if enum_from_base: # pragma: no branch
1659-
break
1672+
enum_from_base = base_field.data_type.find_source(Enum)
16601673
if enum_from_base: # pragma: no branch
16611674
break
1675+
if enum_from_base: # pragma: no branch
1676+
break
16621677

16631678
has_one_literal = False
16641679
for discriminator_field in discriminator_model.fields:
@@ -1690,11 +1705,7 @@ def get_discriminator_field_value(
16901705
discriminator_field.extras["is_classvar"] = True
16911706
break
16921707

1693-
enum_source: Enum | None = None
1694-
if self.use_enum_values_in_discriminator:
1695-
enum_source = ( # pragma: no cover
1696-
discriminator_field.data_type.find_source(Enum) or enum_from_base
1697-
)
1708+
enum_source = discriminator_field.data_type.find_source(Enum) or enum_from_base
16981709

16991710
for field_data_type in discriminator_field.data_type.all_data_types:
17001711
if field_data_type.reference: # pragma: no cover

0 commit comments

Comments
 (0)