Skip to content

Commit 2ebaea8

Browse files
committed
fix: Handle list[Model], dict[str, Model] and Optional[Model] in nested type detection
The nested schema type detection from PR #426 did not resolve $ref for list items, dict additionalProperties, or anyOf (Optional) fields. This caused a KeyError or incorrect type inference when using Pydantic or dataclass models with fields like list[NestedModel], dict[str, NestedModel], or Optional[NestedModel].

Changes:
- Resolve $ref in array items for list[NestedModel] support
- Resolve $ref in additionalProperties for dict[str, NestedModel] support
- Resolve $ref in anyOf for Optional[NestedModel] support
- Include optional properties in property_order so fields with defaults are not skipped during schema traversal
- Extend __init__ wrapper to convert list[dict] and dict[str, dict] to their respective nested types

Ref: flyteorg/flyte#6887

Signed-off-by: André Ahlert <andre@aex.partners>
1 parent 28396f2 commit 2ebaea8

File tree

2 files changed

+332
-13
lines changed

2 files changed

+332
-13
lines changed

src/flyte/types/_type_engine.py

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -922,11 +922,19 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
922922
from flyte.io._file import File
923923

924924
attribute_list: typing.List[typing.Tuple[Any, Any]] = []
925-
nested_types: typing.Dict[str, type] = {} # Track nested model types for conversion
925+
# Track nested model types for conversion: (container_kind, nested_class)
926+
# container_kind: "direct" for field: Model, "list" for field: list[Model], "dict" for field: dict[str, Model]
927+
nested_types: typing.Dict[str, typing.Tuple[str, type]] = {}
926928

927929
# Use 'required' field to preserve property order, as protobuf Struct doesn't preserve dict order
930+
# Also include optional properties (those not in 'required') so they are not skipped
928931
properties = schema["properties"]
929-
property_order = schema.get("required", list(properties.keys()))
932+
required = schema.get("required", list(properties.keys()))
933+
property_order = list(required)
934+
for key in properties:
935+
if key not in property_order:
936+
property_order.append(key)
937+
defs = schema.get("$defs", schema.get("definitions", {}))
930938

931939
for property_key in property_order:
932940
property_val = properties[property_key]
@@ -935,8 +943,6 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
935943
ref_path = property_val["$ref"]
936944
# Extract the definition name from the $ref path (e.g., "#/$defs/MyNestedModel" -> "MyNestedModel")
937945
ref_name = ref_path.split("/")[-1]
938-
# Get the referenced schema from $defs (or definitions for older schemas)
939-
defs = schema.get("$defs", schema.get("definitions", {}))
940946
if ref_name in defs:
941947
ref_schema = defs[ref_name].copy()
942948
# Check if the $ref points to an enum definition (no properties)
@@ -954,18 +960,52 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
954960
)
955961
)
956962
# Track this as a nested type that needs dict-to-object conversion
957-
nested_types[property_key] = nested_class
963+
nested_types[property_key] = ("direct", nested_class)
958964
continue
959965

960966
if property_val.get("anyOf"):
961-
property_type = property_val["anyOf"][0]["type"]
967+
first_option = property_val["anyOf"][0]
968+
if first_option.get("$ref"):
969+
# Handle Optional[NestedModel] - anyOf with $ref and null
970+
ref_path = first_option["$ref"]
971+
ref_name = ref_path.split("/")[-1]
972+
if ref_name in defs:
973+
ref_schema = defs[ref_name].copy()
974+
if ref_schema.get("enum"):
975+
attribute_list.append((property_key, typing.Optional[str]))
976+
else:
977+
if "$defs" not in ref_schema and defs:
978+
ref_schema["$defs"] = defs
979+
nested_class = convert_mashumaro_json_schema_to_python_class(ref_schema, ref_name)
980+
attribute_list.append((property_key, typing.Optional[typing.cast(GenericAlias, nested_class)]))
981+
nested_types[property_key] = ("direct", nested_class)
982+
continue
983+
property_type = first_option["type"]
962984
elif property_val.get("enum"):
963985
property_type = "enum"
964986
else:
965987
property_type = property_val["type"]
966988
# Handle list
967989
if property_type == "array":
968-
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore
990+
items = property_val["items"]
991+
if isinstance(items, dict) and items.get("$ref"):
992+
# Handle list[NestedModel]
993+
ref_path = items["$ref"]
994+
ref_name = ref_path.split("/")[-1]
995+
if ref_name in defs:
996+
ref_schema = defs[ref_name].copy()
997+
if ref_schema.get("enum"):
998+
attribute_list.append((property_key, typing.List[str])) # type: ignore
999+
else:
1000+
if "$defs" not in ref_schema and defs:
1001+
ref_schema["$defs"] = defs
1002+
nested_class = convert_mashumaro_json_schema_to_python_class(ref_schema, ref_name)
1003+
attribute_list.append((property_key, typing.List[typing.cast(GenericAlias, nested_class)])) # type: ignore
1004+
nested_types[property_key] = ("list", nested_class)
1005+
else:
1006+
attribute_list.append((property_key, typing.List[_get_element_type(items)])) # type: ignore
1007+
else:
1008+
attribute_list.append((property_key, typing.List[_get_element_type(items)])) # type: ignore
9691009
# Handle dataclass and dict
9701010
elif property_type == "object":
9711011
if property_val.get("anyOf"):
@@ -1001,11 +1041,31 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
10011041
typing.cast(GenericAlias, nested_class),
10021042
)
10031043
)
1004-
nested_types[property_key] = nested_class
1044+
nested_types[property_key] = ("direct", nested_class)
10051045
elif property_val.get("additionalProperties"):
10061046
# For typing.Dict type
1007-
elem_type = _get_element_type(property_val["additionalProperties"])
1008-
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
1047+
additional = property_val["additionalProperties"]
1048+
if isinstance(additional, dict) and additional.get("$ref"):
1049+
# Handle dict[str, NestedModel]
1050+
ref_path = additional["$ref"]
1051+
ref_name = ref_path.split("/")[-1]
1052+
if ref_name in defs:
1053+
ref_schema = defs[ref_name].copy()
1054+
if ref_schema.get("enum"):
1055+
attribute_list.append((property_key, typing.Dict[str, str])) # type: ignore
1056+
else:
1057+
if "$defs" not in ref_schema and defs:
1058+
ref_schema["$defs"] = defs
1059+
nested_class = convert_mashumaro_json_schema_to_python_class(ref_schema, ref_name)
1060+
casted = typing.cast(GenericAlias, nested_class)
1061+
attribute_list.append((property_key, typing.Dict[str, casted])) # type: ignore
1062+
nested_types[property_key] = ("dict", nested_class)
1063+
else:
1064+
elem_type = _get_element_type(additional)
1065+
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
1066+
else:
1067+
elem_type = _get_element_type(additional)
1068+
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
10091069
elif property_val.get("title"):
10101070
# For nested dataclass
10111071
sub_schemea_name = property_val["title"]
@@ -1039,7 +1099,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
10391099
typing.cast(GenericAlias, nested_class),
10401100
)
10411101
)
1042-
nested_types[property_key] = nested_class
1102+
nested_types[property_key] = ("direct", nested_class)
10431103
else:
10441104
# For untyped dict
10451105
attribute_list.append((property_key, dict)) # type: ignore
@@ -2064,11 +2124,17 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ
20642124

20652125
def __init__(self, *args, **kwargs): # type: ignore[misc]
20662126
# Convert dict values to nested types before calling original __init__
2067-
for field_name, field_type in nested_types.items():
2127+
for field_name, (kind, field_type) in nested_types.items():
20682128
if field_name in kwargs:
20692129
value = kwargs[field_name]
2070-
if isinstance(value, dict):
2130+
if kind == "direct" and isinstance(value, dict):
20712131
kwargs[field_name] = field_type(**value)
2132+
elif kind == "list" and isinstance(value, list):
2133+
kwargs[field_name] = [field_type(**item) if isinstance(item, dict) else item for item in value]
2134+
elif kind == "dict" and isinstance(value, dict):
2135+
kwargs[field_name] = {
2136+
k: field_type(**v) if isinstance(v, dict) else v for k, v in value.items()
2137+
}
20722138
original_init(self, *args, **kwargs)
20732139

20742140
cls.__init__ = __init__ # type: ignore[method-assign, misc]

0 commit comments

Comments
 (0)