Skip to content

Commit c858a99

Browse files
authored
Added support to handle nested Pydantic types (#640)
Summary - Fix `_get_element_type()` in the type engine to correctly resolve nested and recursive Pydantic types that were previously flattened to str - Add $ref resolution so nested models or enums inside containers (List, Dict) are resolved via $defs lookup instead of defaulting to str - Rewrite anyOf handling to support $ref variants (e.g. Optional[Inner]), fixing KeyError: 'type' crashes - Add recursive array/object branches so nested containers (List[List[int]], Dict[str, Dict[str, int]]) resolve to their correct Python types - Handle null JSON schema type, returning NoneType instead of str - Add example demonstrating all nested Pydantic type patterns running as parallel Flyte tasks Test plan - pytest `tests/flyte/type_engine/pydantic/test_nested_structs_in_pydantic.py` : 16 new tests covering - - Nested arrays: List[List[int]] schema resolution + roundtrip, List[List[Inner]] roundtrip - Nested dicts: List[Dict[str, int]] schema + roundtrip, Dict[str, Dict[str, int]] schema + roundtrip, List[Dict[str, Inner]] roundtrip, Dict[str, Dict[str, Inner]] roundtrip - anyOf with $ref: List[Optional[Inner]] schema resolution (no KeyError), List[Optional[str]] schema resolution, List[Optional[Inner]] roundtrip via guess_python_type - null type: {"type": "null"} resolves to NoneType - Enums in containers: List[ModelWithEnum] roundtrip, Dict[str, ModelWithEnum] roundtrip - Complex combined: model with List[List[Inner]], Dict[str, Inner], List[Dict[str, int]], List[ModelWithEnum], Optional[Inner] roundtrip with data and with None optional All previous tests pass --------- Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com>
1 parent 9a4eca1 commit c858a99

File tree

3 files changed

+496
-13
lines changed

3 files changed

+496
-13
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import asyncio
2+
from enum import Enum
3+
from typing import Dict, List, Optional
4+
5+
from pydantic import BaseModel
6+
7+
import flyte
8+
9+
env = flyte.TaskEnvironment(name="inputs_pydantic_nested_types")
10+
11+
12+
class Status(str, Enum):
13+
ACTIVE = "active"
14+
INACTIVE = "inactive"
15+
16+
17+
class Inner(BaseModel):
18+
name: str
19+
value: int
20+
21+
22+
class ModelWithEnum(BaseModel):
23+
label: str
24+
status: Status
25+
26+
27+
@env.task
28+
async def nested_lists(matrix: List[List[int]]) -> str:
29+
return f"Matrix {len(matrix)}x{len(matrix[0])}: {matrix}"
30+
31+
32+
@env.task
33+
async def list_of_dicts(records: List[Dict[str, int]]) -> str:
34+
return f"Records ({len(records)} entries): {records}"
35+
36+
37+
@env.task
38+
async def dict_of_dicts(nested_map: Dict[str, Dict[str, int]]) -> str:
39+
return f"Nested map keys: {list(nested_map.keys())}, values: {nested_map}"
40+
41+
42+
@env.task
43+
async def nested_models(items: List[List[Inner]]) -> str:
44+
names = [[m.name for m in row] for row in items]
45+
return f"Nested model names: {names}"
46+
47+
48+
@env.task
49+
async def dict_of_models(models: Dict[str, Inner]) -> str:
50+
return f"Dict of models: {list(models.keys())}"
51+
52+
53+
@env.task
54+
async def enum_in_models(jobs: List[ModelWithEnum]) -> str:
55+
return f"Jobs: {[(j.label, j.status.value) for j in jobs]}"
56+
57+
58+
@env.task
59+
async def optional_model(inner: Optional[Inner] = None) -> str:
60+
return f"Optional inner: {inner}"
61+
62+
63+
class ComplexNestedModel(BaseModel):
64+
nested_list: List[List[Inner]]
65+
dict_of_model_lists: Dict[str, List[Inner]]
66+
list_of_model_dicts: List[Dict[str, Inner]]
67+
enum_model_map: Dict[str, ModelWithEnum]
68+
list_of_dicts: List[Dict[str, int]]
69+
optional_inner: Optional[Inner] = None
70+
71+
72+
@env.task
73+
async def complex_nesting(data: ComplexNestedModel) -> str:
74+
result = f"Nested list rows: {len(data.nested_list)}"
75+
result += f"\nDict of model lists keys: {list(data.dict_of_model_lists.keys())}"
76+
result += f"\nList of model dicts count: {len(data.list_of_model_dicts)}"
77+
result += f"\nEnum model map: {[(k, v.status.value) for k, v in data.enum_model_map.items()]}"
78+
result += f"\nList of dicts: {data.list_of_dicts}"
79+
result += f"\nOptional inner: {data.optional_inner}"
80+
return result
81+
82+
83+
@env.task
84+
async def main() -> str:
85+
r1, r2, r3, r4, r5, r6, r7, r8, r9 = await asyncio.gather(
86+
nested_lists(matrix=[[1, 2], [3, 4]]),
87+
list_of_dicts(records=[{"a": 1, "b": 2}, {"c": 3}]),
88+
dict_of_dicts(nested_map={"outer": {"inner_key": 10}}),
89+
nested_models(items=[[Inner(name="a", value=1)], [Inner(name="b", value=2)]]),
90+
dict_of_models(models={"x": Inner(name="c", value=3)}),
91+
enum_in_models(
92+
jobs=[
93+
ModelWithEnum(label="job1", status=Status.ACTIVE),
94+
ModelWithEnum(label="job2", status=Status.INACTIVE),
95+
]
96+
),
97+
optional_model(inner=Inner(name="d", value=4)),
98+
optional_model(),
99+
complex_nesting(
100+
data=ComplexNestedModel(
101+
nested_list=[[Inner(name="a", value=1)], [Inner(name="b", value=2)]],
102+
dict_of_model_lists={"group1": [Inner(name="c", value=3), Inner(name="d", value=4)]},
103+
list_of_model_dicts=[{"x": Inner(name="e", value=5)}, {"y": Inner(name="f", value=6)}],
104+
enum_model_map={
105+
"j1": ModelWithEnum(label="job1", status=Status.ACTIVE),
106+
"j2": ModelWithEnum(label="job2", status=Status.INACTIVE),
107+
},
108+
list_of_dicts=[{"k1": 10, "k2": 20}],
109+
optional_inner=Inner(name="g", value=7),
110+
)
111+
),
112+
)
113+
114+
return f"{r1}\n{r2}\n{r3}\n{r4}\n{r5}\n{r6}\n{r7}\n{r8}\n{r9}"
115+
116+
117+
if __name__ == "__main__":
118+
flyte.init_from_config()
119+
120+
r = flyte.run(main)
121+
print(r.name)
122+
print(r.url)
123+
r.wait()

src/flyte/types/_type_engine.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
965965
property_type = property_val["type"]
966966
# Handle list
967967
if property_type == "array":
968-
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore
968+
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"], schema)])) # type: ignore
969969
# Handle dataclass and dict
970970
elif property_type == "object":
971971
if property_val.get("anyOf"):
@@ -1004,7 +1004,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
10041004
nested_types[property_key] = nested_class
10051005
elif property_val.get("additionalProperties"):
10061006
# For typing.Dict type
1007-
elem_type = _get_element_type(property_val["additionalProperties"])
1007+
elem_type = _get_element_type(property_val["additionalProperties"], schema)
10081008
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
10091009
elif property_val.get("title"):
10101010
# For nested dataclass
@@ -1047,7 +1047,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
10471047
attribute_list.append([property_key, str]) # type: ignore
10481048
# Handle int, float, bool or str
10491049
else:
1050-
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
1050+
attribute_list.append([property_key, _get_element_type(property_val, schema)]) # type: ignore
10511051
return attribute_list, nested_types
10521052

10531053

@@ -2076,7 +2076,14 @@ def __init__(self, *args, **kwargs): # type: ignore[misc]
20762076
return cls
20772077

20782078

2079-
def _get_element_type(element_property: typing.Union[typing.Dict[str, str], bool]) -> Type:
2079+
# The value in a JSON schema doesn't always have to be a string, they can be dicts e.g. items, additionalProperties,
2080+
# anyOf, lists or bool. The old type hint was inaccurate.
2081+
# New parameter added for schema. `_get_element_type` needs to look up $defs when resolving $ref paths. Default
2082+
# - None, backward compatible.
2083+
def _get_element_type(
2084+
element_property: typing.Union[typing.Dict[str, typing.Any], bool],
2085+
schema: typing.Optional[typing.Dict[str, typing.Any]] = None,
2086+
) -> Type:
20802087
from flyte.io._dir import Dir
20812088
from flyte.io._file import File
20822089

@@ -2091,16 +2098,48 @@ def _get_element_type(element_property: typing.Union[typing.Dict[str, str], bool
20912098
return File
20922099
elif Dir.schema_match(element_property):
20932100
return Dir
2094-
element_type = (
2095-
[e_property["type"] for e_property in element_property["anyOf"]] # type: ignore
2096-
if element_property.get("anyOf")
2097-
else element_property["type"]
2098-
)
2099-
element_format = element_property["format"] if "format" in element_property else None
21002101

2101-
if isinstance(element_type, list):
2102-
# Element type of Optional[int] is [integer, None]
2103-
return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore
2102+
# Handle $ref for nested models and enums
2103+
2104+
# Ensure that the element is actually a $ref and we have the entire schema to look up
2105+
if element_property.get("$ref") and schema is not None:
2106+
ref_name = element_property["$ref"].split("/")[-1]
2107+
defs = schema.get("$defs", schema.get("definitions", {}))
2108+
# Look up for ref_name in the defs defined in the schema
2109+
if ref_name in defs:
2110+
# Don't mutate the orignal schema
2111+
ref_schema = defs[ref_name].copy()
2112+
# Guard the nested enum elements inside containers
2113+
if ref_schema.get("enum"):
2114+
return str
2115+
# if defs not in the schema, they need to be propogated into the resolved schema
2116+
if "$defs" not in ref_schema and defs:
2117+
ref_schema["$defs"] = defs
2118+
# build a dataclass from the resolved schema
2119+
return convert_mashumaro_json_schema_to_python_class(ref_schema, ref_name)
2120+
# default to str on failure. Shouldn't happen with valid pydantic schemas
2121+
return str
2122+
2123+
# Handle anyOf (e.g. Optional[int], Optional[Inner])
2124+
# Early return block replacing the previous list comprehension which would fail when an anyOf reference was a $ref
2125+
# (meaning no $type key).
2126+
if element_property.get("anyOf"):
2127+
# Separate non null variants. Note a $ref variant would have type None NOT null. A {"type": "null"} variant is
2128+
# filtered out.
2129+
variants = element_property["anyOf"]
2130+
non_null = [v for v in variants if v.get("type") != "null"]
2131+
# Detect if this is an Optional pattern here
2132+
has_null = len(non_null) < len(variants)
2133+
# This recurses on the first non-null variant which would handle the $ref, nested_arrays, nested_objects...
2134+
# anything. Wrap it in Optional if has_null.
2135+
if non_null:
2136+
inner_type = _get_element_type(non_null[0], schema)
2137+
return typing.Optional[inner_type] if has_null else inner_type # type: ignore
2138+
# return None if all types are None
2139+
return type(None)
2140+
2141+
element_type = element_property.get("type", "string")
2142+
element_format = element_property.get("format")
21042143

21052144
if element_type == "string":
21062145
return str
@@ -2113,6 +2152,16 @@ def _get_element_type(element_property: typing.Union[typing.Dict[str, str], bool
21132152
return int
21142153
else:
21152154
return float
2155+
# Recursively discover the types when an array or object element type is discovered
2156+
elif element_type == "array":
2157+
return typing.List[_get_element_type(element_property.get("items", {}), schema)] # type: ignore
2158+
elif element_type == "object":
2159+
if element_property.get("additionalProperties"):
2160+
return typing.Dict[str, _get_element_type(element_property["additionalProperties"], schema)] # type: ignore
2161+
return dict
2162+
# Corner case - practically useless but List[None] is a legal Python type
2163+
elif element_type == "null":
2164+
return type(None)
21162165
return str
21172166

21182167

0 commit comments

Comments
 (0)