Skip to content

Commit ef7c699

Browse files
committed
[python-fastapi] support oneOf the pydantic v2 way
* Support oneOf and anyOf schemas the pydantic v2 way by generating them as Unions. * Generate model constructor that forcefully sets the discriminator field to ensure it is included in the marshalled representation.
1 parent f57a1d5 commit ef7c699

File tree

21 files changed

+218
-340
lines changed

21 files changed

+218
-340
lines changed

modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/AbstractPythonCodegen.java

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,8 @@ public Map<String, ModelsMap> postProcessAllModels(Map<String, ModelsMap> objs)
823823
codegenModelMap.put(cm.classname, ModelUtils.getModelByName(entry.getKey(), objs));
824824
}
825825

826+
propagateDiscriminatorValuesToProperties(processed);
827+
826828
// create circular import
827829
for (String m : codegenModelMap.keySet()) {
828830
createImportMapOfSet(m, codegenModelMap);
@@ -1018,6 +1020,52 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
10181020
return objs;
10191021
}
10201022

1023+
private void propagateDiscriminatorValuesToProperties(Map<String, ModelsMap> objMap) {
1024+
HashMap<String, CodegenModel> modelMap = new HashMap<>();
1025+
for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
1026+
for (ModelMap m : entry.getValue().getModels()) {
1027+
modelMap.put("#/components/schemas/" + entry.getKey(), m.getModel());
1028+
}
1029+
}
1030+
1031+
for (Map.Entry<String, ModelsMap> entry : objMap.entrySet()) {
1032+
for (ModelMap m : entry.getValue().getModels()) {
1033+
CodegenModel model = m.getModel();
1034+
if (model.discriminator != null && !model.oneOf.isEmpty()) {
1035+
// Populate default, implicit discriminator values
1036+
for (String typeName : model.oneOf) {
1037+
ModelsMap obj = objMap.get(typeName);
1038+
if (obj == null) {
1039+
continue;
1040+
}
1041+
for (ModelMap m1 : obj.getModels()) {
1042+
for (CodegenProperty p : m1.getModel().vars) {
1043+
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
1044+
p.isDiscriminator = true;
1045+
p.discriminatorValue = typeName;
1046+
}
1047+
}
1048+
}
1049+
}
1050+
// Populate explicit discriminator values from mapping, overwriting default values
1051+
if (model.discriminator.getMapping() != null) {
1052+
for (Map.Entry<String, String> discrEntry : model.discriminator.getMapping().entrySet()) {
1053+
CodegenModel resolved = modelMap.get(discrEntry.getValue());
1054+
if (resolved != null) {
1055+
for (CodegenProperty p : resolved.vars) {
1056+
if (p.baseName.equals(model.discriminator.getPropertyBaseName())) {
1057+
p.isDiscriminator = true;
1058+
p.discriminatorValue = discrEntry.getKey();
1059+
}
1060+
}
1061+
}
1062+
}
1063+
}
1064+
}
1065+
}
1066+
}
1067+
}
1068+
10211069

10221070
/*
10231071
* Gets the pydantic type given a Codegen Property
@@ -2134,7 +2182,16 @@ private PythonType getType(CodegenProperty cp) {
21342182
}
21352183

21362184
private String finalizeType(CodegenProperty cp, PythonType pt) {
2137-
if (!cp.required || cp.isNullable) {
2185+
if (cp.isDiscriminator && cp.discriminatorValue != null) {
2186+
moduleImports.add("typing", "Literal");
2187+
PythonType literal = new PythonType("Literal");
2188+
String literalValue = '"'+escapeText(cp.discriminatorValue)+'"';
2189+
PythonType valueType = new PythonType(literalValue);
2190+
literal.addTypeParam(valueType);
2191+
literal.setDefaultValue(literalValue);
2192+
cp.setDefaultValue(literalValue);
2193+
pt = literal;
2194+
} else if (!cp.required || cp.isNullable) {
21382195
moduleImports.add("typing", "Optional");
21392196
PythonType opt = new PythonType("Optional");
21402197
opt.addTypeParam(pt);

modules/openapi-generator/src/main/resources/python-fastapi/model_anyof.mustache

Lines changed: 28 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -14,174 +14,56 @@ import re # noqa: F401
1414
{{/vendorExtensions.x-py-model-imports}}
1515
from typing import Union, Any, List, TYPE_CHECKING, Optional, Dict
1616
from typing_extensions import Literal
17-
from pydantic import StrictStr, Field
17+
from pydantic import StrictStr, Field, RootModel
1818
try:
1919
from typing import Self
2020
except ImportError:
2121
from typing_extensions import Self
2222

23-
{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS = [{{#anyOf}}"{{.}}"{{^-last}}, {{/-last}}{{/anyOf}}]
24-
25-
class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}):
23+
class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}RootModel{{/parent}}):
2624
"""
2725
{{{description}}}{{^description}}{{{classname}}}{{/description}}
2826
"""
2927

30-
{{#composedSchemas.anyOf}}
31-
# data type: {{{dataType}}}
32-
{{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
33-
{{/composedSchemas.anyOf}}
34-
if TYPE_CHECKING:
35-
actual_instance: Optional[Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}]] = None
36-
else:
37-
actual_instance: Any = None
38-
any_of_schemas: List[str] = Literal[{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ANY_OF_SCHEMAS]
28+
root: Union[{{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}] = None
3929

4030
model_config = {
4131
"validate_assignment": True,
4232
"protected_namespaces": (),
4333
}
44-
{{#discriminator}}
45-
46-
discriminator_value_class_map: Dict[str, str] = {
47-
{{#children}}
48-
'{{^vendorExtensions.x-discriminator-value}}{{name}}{{/vendorExtensions.x-discriminator-value}}{{#vendorExtensions.x-discriminator-value}}{{{vendorExtensions.x-discriminator-value}}}{{/vendorExtensions.x-discriminator-value}}': '{{{classname}}}'{{^-last}},{{/-last}}
49-
{{/children}}
50-
}
51-
{{/discriminator}}
52-
53-
def __init__(self, *args, **kwargs) -> None:
54-
if args:
55-
if len(args) > 1:
56-
raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
57-
if kwargs:
58-
raise ValueError("If a position argument is used, keyword arguments cannot be used.")
59-
super().__init__(actual_instance=args[0])
60-
else:
61-
super().__init__(**kwargs)
62-
63-
@field_validator('actual_instance')
64-
def actual_instance_must_validate_anyof(cls, v):
65-
{{#isNullable}}
66-
if v is None:
67-
return v
68-
69-
{{/isNullable}}
70-
instance = {{{classname}}}.model_construct()
71-
error_messages = []
72-
{{#composedSchemas.anyOf}}
73-
# validate data type: {{{dataType}}}
74-
{{#isContainer}}
75-
try:
76-
instance.{{vendorExtensions.x-py-name}} = v
77-
return v
78-
except (ValidationError, ValueError) as e:
79-
error_messages.append(str(e))
80-
{{/isContainer}}
81-
{{^isContainer}}
82-
{{#isPrimitiveType}}
83-
try:
84-
instance.{{vendorExtensions.x-py-name}} = v
85-
return v
86-
except (ValidationError, ValueError) as e:
87-
error_messages.append(str(e))
88-
{{/isPrimitiveType}}
89-
{{^isPrimitiveType}}
90-
if not isinstance(v, {{{dataType}}}):
91-
error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`")
92-
else:
93-
return v
94-
95-
{{/isPrimitiveType}}
96-
{{/isContainer}}
97-
{{/composedSchemas.anyOf}}
98-
if error_messages:
99-
# no match
100-
raise ValueError("No match found when setting the actual_instance in {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
101-
else:
102-
return v
103-
104-
@classmethod
105-
def from_dict(cls, obj: dict) -> Self:
106-
return cls.from_json(json.dumps(obj))
10734

108-
@classmethod
109-
def from_json(cls, json_str: str) -> Self:
110-
"""Returns the object represented by the json string"""
111-
instance = cls.model_construct()
112-
{{#isNullable}}
113-
if json_str is None:
114-
return instance
115-
116-
{{/isNullable}}
117-
error_messages = []
118-
{{#composedSchemas.anyOf}}
119-
{{#isContainer}}
120-
# deserialize data into {{{dataType}}}
121-
try:
122-
# validation
123-
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
124-
# assign value to actual_instance
125-
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
126-
return instance
127-
except (ValidationError, ValueError) as e:
128-
error_messages.append(str(e))
129-
{{/isContainer}}
130-
{{^isContainer}}
131-
{{#isPrimitiveType}}
132-
# deserialize data into {{{dataType}}}
133-
try:
134-
# validation
135-
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
136-
# assign value to actual_instance
137-
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
138-
return instance
139-
except (ValidationError, ValueError) as e:
140-
error_messages.append(str(e))
141-
{{/isPrimitiveType}}
142-
{{^isPrimitiveType}}
143-
# {{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
144-
try:
145-
instance.actual_instance = {{{dataType}}}.from_json(json_str)
146-
return instance
147-
except (ValidationError, ValueError) as e:
148-
error_messages.append(str(e))
149-
{{/isPrimitiveType}}
150-
{{/isContainer}}
151-
{{/composedSchemas.anyOf}}
152-
153-
if error_messages:
154-
# no match
155-
raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with anyOf schemas: {{#anyOf}}{{{.}}}{{^-last}}, {{/-last}}{{/anyOf}}. Details: " + ", ".join(error_messages))
156-
else:
157-
return instance
35+
def to_str(self) -> str:
36+
"""Returns the string representation of the model using alias"""
37+
return pprint.pformat(self.model_dump(by_alias=True))
15838

15939
def to_json(self) -> str:
160-
"""Returns the JSON representation of the actual instance"""
161-
if self.actual_instance is None:
162-
return "null"
40+
"""Returns the JSON representation of the model using alias"""
41+
return self.model_dump_json(by_alias=True, exclude_unset=True)
16342

164-
to_json = getattr(self.actual_instance, "to_json", None)
165-
if callable(to_json):
166-
return self.actual_instance.to_json()
43+
@classmethod
44+
def from_json(cls, json_str: str) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
45+
"""Create an instance of {{{classname}}} from a JSON string"""
46+
return cls.from_dict(json.loads(json_str))
47+
48+
def to_dict(self) -> Dict[str, Any]:
49+
"""Return the dictionary representation of the model using alias"""
50+
to_dict = getattr(self.root, "to_dict", None)
51+
if callable(to_dict):
52+
return self.model_dump(by_alias=True, exclude_unset=True)
16753
else:
168-
return json.dumps(self.actual_instance)
54+
# primitive type
55+
return self.root
16956

170-
def to_dict(self) -> Dict:
171-
"""Returns the dict representation of the actual instance"""
172-
if self.actual_instance is None:
173-
return "null"
57+
@classmethod
58+
def from_dict(cls, obj: Dict) -> {{^hasChildren}}Self{{/hasChildren}}{{#hasChildren}}{{#discriminator}}Union[{{#children}}Self{{^-last}}, {{/-last}}{{/children}}]{{/discriminator}}{{^discriminator}}Self{{/discriminator}}{{/hasChildren}}:
59+
"""Create an instance of {{{classname}}} from a dict"""
60+
if obj is None:
61+
return None
17462

175-
to_json = getattr(self.actual_instance, "to_json", None)
176-
if callable(to_json):
177-
return self.actual_instance.to_dict()
178-
else:
179-
# primitive type
180-
return self.actual_instance
63+
if not isinstance(obj, dict):
64+
return cls.model_validate(obj)
18165

182-
def to_str(self) -> str:
183-
"""Returns the string representation of the actual instance"""
184-
return pprint.pformat(self.model_dump())
66+
return cls.parse_obj(obj)
18567

18668
{{#vendorExtensions.x-py-postponed-model-imports.size}}
18769
{{#vendorExtensions.x-py-postponed-model-imports}}

modules/openapi-generator/src/main/resources/python-fastapi/model_generic.mustache

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}
8686
{{/isAdditionalPropertiesTrue}}
8787
}
8888

89+
def __init__(self, *a, **kw):
90+
super().__init__(*a, **kw)
91+
{{#vars}}
92+
{{#isDiscriminator}}
93+
self.{{name}} = self.{{name}}
94+
{{/isDiscriminator}}
95+
{{/vars}}
8996

9097
def to_str(self) -> str:
9198
"""Returns the string representation of the model using alias"""

0 commit comments

Comments
 (0)