Skip to content

Commit 2ea6160

Browse files
fix: Add recursive array/object handling in _get_element_type
Address review feedback: _get_element_type now handles nested containers (List[List[int]], List[Dict[str, int]]) by recursing into array items and object additionalProperties. Also adds List[Optional[int]] confirmation via tests. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 76597a5 commit 2ea6160

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

src/flyte/types/_type_engine.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,6 +2208,21 @@ def _get_element_type(
22082208
# Element type of Optional[int] is [integer, None]
22092209
return typing.Optional[_get_element_type({"type": element_type[0]}, schema)] # type: ignore
22102210

2211+
# Handle nested array (e.g., List[List[int]])
2212+
if element_type == "array":
2213+
inner_items = element_property.get("items", {})
2214+
return typing.List[_get_element_type(inner_items, schema)] # type: ignore
2215+
2216+
# Handle nested object / dict (e.g., List[Dict[str, int]])
2217+
if element_type == "object":
2218+
additional = element_property.get("additionalProperties")
2219+
if additional:
2220+
return typing.Dict[str, _get_element_type(additional, schema)] # type: ignore
2221+
# Nested dataclass-like object with a title
2222+
if element_property.get("title"):
2223+
return convert_mashumaro_json_schema_to_python_class(element_property, element_property["title"])
2224+
return dict
2225+
22112226
if element_type == "string":
22122227
return str
22132228
elif element_type == "integer":

tests/flyte/type_engine/pydantic/test_nested_lists_maps.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ class Profile(BaseModel):
3838
address: Optional[Tag] = None
3939

4040

41+
class NestedLists(BaseModel):
42+
matrix: List[List[int]]
43+
44+
45+
class NestedDicts(BaseModel):
46+
lookup: List[Dict[str, int]]
47+
48+
49+
class OptionalInts(BaseModel):
50+
values: List[Optional[int]]
51+
52+
4153
# -- Dataclass models --
4254

4355

@@ -261,3 +273,69 @@ async def test_pydantic_empty_dict_of_nested():
261273
pv = await TypeEngine.to_python_value(lv, Inventory)
262274
assert pv.items == []
263275
assert pv.metadata == {}
276+
277+
278+
# -- Tests: recursive/nested containers --
279+
280+
281+
@pytest.mark.asyncio
282+
async def test_pydantic_list_of_lists():
283+
"""List[List[int]] should roundtrip correctly through the type engine."""
284+
input_val = NestedLists(matrix=[[1, 2], [3, 4, 5]])
285+
lit = TypeEngine.to_literal_type(NestedLists)
286+
lv = await TypeEngine.to_literal(input_val, python_type=NestedLists, expected=lit)
287+
288+
assert lit
289+
assert lv
290+
291+
guessed = TypeEngine.guess_python_type(lit)
292+
assert guessed
293+
294+
v = guessed(matrix=[[1, 2], [3, 4, 5]])
295+
new_lv = await TypeEngine.to_literal(v, guessed, lit)
296+
assert new_lv == lv
297+
298+
pv = await TypeEngine.to_python_value(new_lv, NestedLists)
299+
assert pv == input_val
300+
301+
302+
@pytest.mark.asyncio
303+
async def test_pydantic_list_of_dicts():
304+
"""List[Dict[str, int]] should roundtrip correctly through the type engine."""
305+
input_val = NestedDicts(lookup=[{"a": 1, "b": 2}, {"c": 3}])
306+
lit = TypeEngine.to_literal_type(NestedDicts)
307+
lv = await TypeEngine.to_literal(input_val, python_type=NestedDicts, expected=lit)
308+
309+
assert lit
310+
assert lv
311+
312+
guessed = TypeEngine.guess_python_type(lit)
313+
assert guessed
314+
315+
v = guessed(lookup=[{"a": 1, "b": 2}, {"c": 3}])
316+
new_lv = await TypeEngine.to_literal(v, guessed, lit)
317+
assert new_lv == lv
318+
319+
pv = await TypeEngine.to_python_value(new_lv, NestedDicts)
320+
assert pv == input_val
321+
322+
323+
@pytest.mark.asyncio
324+
async def test_pydantic_list_of_optional_int():
325+
"""List[Optional[int]] should roundtrip correctly through the type engine."""
326+
input_val = OptionalInts(values=[1, None, 3, None])
327+
lit = TypeEngine.to_literal_type(OptionalInts)
328+
lv = await TypeEngine.to_literal(input_val, python_type=OptionalInts, expected=lit)
329+
330+
assert lit
331+
assert lv
332+
333+
guessed = TypeEngine.guess_python_type(lit)
334+
assert guessed
335+
336+
v = guessed(values=[1, None, 3, None])
337+
new_lv = await TypeEngine.to_literal(v, guessed, lit)
338+
assert new_lv == lv
339+
340+
pv = await TypeEngine.to_python_value(new_lv, OptionalInts)
341+
assert pv == input_val

0 commit comments

Comments
 (0)