@@ -922,11 +922,19 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
922922 from flyte .io ._file import File
923923
924924 attribute_list : typing .List [typing .Tuple [Any , Any ]] = []
925- nested_types : typing .Dict [str , type ] = {} # Track nested model types for conversion
925+ # Track nested model types for conversion: (container_kind, nested_class)
926+ # container_kind: "direct" for field: Model, "list" for field: list[Model], "dict" for field: dict[str, Model]
927+ nested_types : typing .Dict [str , typing .Tuple [str , type ]] = {}
926928
927929 # Use 'required' field to preserve property order, as protobuf Struct doesn't preserve dict order
930+ # Also include optional properties (those not in 'required') so they are not skipped
928931 properties = schema ["properties" ]
929- property_order = schema .get ("required" , list (properties .keys ()))
932+ required = schema .get ("required" , list (properties .keys ()))
933+ property_order = list (required )
934+ for key in properties :
935+ if key not in property_order :
936+ property_order .append (key )
937+ defs = schema .get ("$defs" , schema .get ("definitions" , {}))
930938
931939 for property_key in property_order :
932940 property_val = properties [property_key ]
@@ -935,8 +943,6 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
935943 ref_path = property_val ["$ref" ]
936944 # Extract the definition name from the $ref path (e.g., "#/$defs/MyNestedModel" -> "MyNestedModel")
937945 ref_name = ref_path .split ("/" )[- 1 ]
938- # Get the referenced schema from $defs (or definitions for older schemas)
939- defs = schema .get ("$defs" , schema .get ("definitions" , {}))
940946 if ref_name in defs :
941947 ref_schema = defs [ref_name ].copy ()
942948 # Check if the $ref points to an enum definition (no properties)
@@ -954,18 +960,52 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
954960 )
955961 )
956962 # Track this as a nested type that needs dict-to-object conversion
957- nested_types [property_key ] = nested_class
963+ nested_types [property_key ] = ( "direct" , nested_class )
958964 continue
959965
960966 if property_val .get ("anyOf" ):
961- property_type = property_val ["anyOf" ][0 ]["type" ]
967+ first_option = property_val ["anyOf" ][0 ]
968+ if first_option .get ("$ref" ):
969+ # Handle Optional[NestedModel] - anyOf with $ref and null
970+ ref_path = first_option ["$ref" ]
971+ ref_name = ref_path .split ("/" )[- 1 ]
972+ if ref_name in defs :
973+ ref_schema = defs [ref_name ].copy ()
974+ if ref_schema .get ("enum" ):
975+ attribute_list .append ((property_key , typing .Optional [str ]))
976+ else :
977+ if "$defs" not in ref_schema and defs :
978+ ref_schema ["$defs" ] = defs
979+ nested_class = convert_mashumaro_json_schema_to_python_class (ref_schema , ref_name )
980+ attribute_list .append ((property_key , typing .Optional [typing .cast (GenericAlias , nested_class )]))
981+ nested_types [property_key ] = ("direct" , nested_class )
982+ continue
983+ property_type = first_option ["type" ]
962984 elif property_val .get ("enum" ):
963985 property_type = "enum"
964986 else :
965987 property_type = property_val ["type" ]
966988 # Handle list
967989 if property_type == "array" :
968- attribute_list .append ((property_key , typing .List [_get_element_type (property_val ["items" ])])) # type: ignore
990+ items = property_val ["items" ]
991+ if isinstance (items , dict ) and items .get ("$ref" ):
992+ # Handle list[NestedModel]
993+ ref_path = items ["$ref" ]
994+ ref_name = ref_path .split ("/" )[- 1 ]
995+ if ref_name in defs :
996+ ref_schema = defs [ref_name ].copy ()
997+ if ref_schema .get ("enum" ):
998+ attribute_list .append ((property_key , typing .List [str ])) # type: ignore
999+ else :
1000+ if "$defs" not in ref_schema and defs :
1001+ ref_schema ["$defs" ] = defs
1002+ nested_class = convert_mashumaro_json_schema_to_python_class (ref_schema , ref_name )
1003+ attribute_list .append ((property_key , typing .List [typing .cast (GenericAlias , nested_class )])) # type: ignore
1004+ nested_types [property_key ] = ("list" , nested_class )
1005+ else :
1006+ attribute_list .append ((property_key , typing .List [_get_element_type (items )])) # type: ignore
1007+ else :
1008+ attribute_list .append ((property_key , typing .List [_get_element_type (items )])) # type: ignore
9691009 # Handle dataclass and dict
9701010 elif property_type == "object" :
9711011 if property_val .get ("anyOf" ):
@@ -1001,11 +1041,31 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
10011041 typing .cast (GenericAlias , nested_class ),
10021042 )
10031043 )
1004- nested_types [property_key ] = nested_class
1044+ nested_types [property_key ] = ( "direct" , nested_class )
10051045 elif property_val .get ("additionalProperties" ):
10061046 # For typing.Dict type
1007- elem_type = _get_element_type (property_val ["additionalProperties" ])
1008- attribute_list .append ((property_key , typing .Dict [str , elem_type ])) # type: ignore
1047+ additional = property_val ["additionalProperties" ]
1048+ if isinstance (additional , dict ) and additional .get ("$ref" ):
1049+ # Handle dict[str, NestedModel]
1050+ ref_path = additional ["$ref" ]
1051+ ref_name = ref_path .split ("/" )[- 1 ]
1052+ if ref_name in defs :
1053+ ref_schema = defs [ref_name ].copy ()
1054+ if ref_schema .get ("enum" ):
1055+ attribute_list .append ((property_key , typing .Dict [str , str ])) # type: ignore
1056+ else :
1057+ if "$defs" not in ref_schema and defs :
1058+ ref_schema ["$defs" ] = defs
1059+ nested_class = convert_mashumaro_json_schema_to_python_class (ref_schema , ref_name )
1060+ casted = typing .cast (GenericAlias , nested_class )
1061+ attribute_list .append ((property_key , typing .Dict [str , casted ])) # type: ignore
1062+ nested_types [property_key ] = ("dict" , nested_class )
1063+ else :
1064+ elem_type = _get_element_type (additional )
1065+ attribute_list .append ((property_key , typing .Dict [str , elem_type ])) # type: ignore
1066+ else :
1067+ elem_type = _get_element_type (additional )
1068+ attribute_list .append ((property_key , typing .Dict [str , elem_type ])) # type: ignore
10091069 elif property_val .get ("title" ):
10101070 # For nested dataclass
10111071 sub_schemea_name = property_val ["title" ]
@@ -1039,7 +1099,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
10391099 typing .cast (GenericAlias , nested_class ),
10401100 )
10411101 )
1042- nested_types [property_key ] = nested_class
1102+ nested_types [property_key ] = ( "direct" , nested_class )
10431103 else :
10441104 # For untyped dict
10451105 attribute_list .append ((property_key , dict )) # type: ignore
@@ -2064,11 +2124,17 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ
20642124
def __init__(self, *args, **kwargs):  # type: ignore[misc]
    """Coerce raw nested payloads into model instances, then run the original initializer.

    Only keyword arguments are inspected; positional arguments are forwarded untouched.
    Each ``nested_types`` entry maps a field name to ``(container_kind, nested_class)``:

    * ``"direct"`` - a dict value becomes ``nested_class(**value)``
    * ``"list"``   - every dict element of a list value is converted
    * ``"dict"``   - every dict value of a mapping is converted (keys are kept)

    Values that do not have the expected container shape (e.g. ``None`` or an
    already-built instance) are passed through unchanged.
    """
    for name, (container_kind, model_cls) in nested_types.items():
        if name not in kwargs:
            continue
        raw = kwargs[name]
        if container_kind == "direct":
            if isinstance(raw, dict):
                kwargs[name] = model_cls(**raw)
        elif container_kind == "list":
            if isinstance(raw, list):
                kwargs[name] = [
                    model_cls(**element) if isinstance(element, dict) else element
                    for element in raw
                ]
        elif container_kind == "dict":
            if isinstance(raw, dict):
                kwargs[name] = {
                    key: model_cls(**entry) if isinstance(entry, dict) else entry
                    for key, entry in raw.items()
                }
    # Delegate to the initializer captured by the enclosing scope.
    original_init(self, *args, **kwargs)
20732139
20742140 cls .__init__ = __init__ # type: ignore[method-assign, misc]
0 commit comments