|
16 | 16 | import typing |
17 | 17 | from abc import ABC, abstractmethod |
18 | 18 | from collections import OrderedDict |
19 | | -from functools import lru_cache |
| 19 | +from functools import lru_cache, reduce |
20 | 20 | from types import GenericAlias |
21 | 21 | from typing import Any, Dict, List, NamedTuple, Optional, Type, cast |
22 | 22 |
|
@@ -1089,56 +1089,78 @@ def assert_type(self, t: Type[enum.Enum], v: T): |
1089 | 1089 | raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}") |
1090 | 1090 |
|
1091 | 1091 |
|
1092 | | -def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): |
| 1092 | +def _handle_json_schema_property( |
| 1093 | + property_key: str, |
| 1094 | + property_val: dict, |
| 1095 | +) -> typing.Tuple[str, typing.Any]: |
| 1096 | + """ |
| 1097 | + A helper to handle the properties of a JSON schema and returns their equivalent Flyte attribute name and type. |
| 1098 | + """ |
| 1099 | + |
| 1100 | + # Handle Optional[T] or Union[T1, T2, ...] at the top level for proper recursion |
| 1101 | + if property_val.get("anyOf"): |
| 1102 | + # Sanity check 'anyOf' is not empty |
| 1103 | + assert len(property_val["anyOf"]) > 0 |
| 1104 | + # Check that there are no nested Optional or Union types - no need to support that pattern |
| 1105 | + # as it would just add complexity without much benefit |
| 1106 | + # A few examples: Optional[Optional[T]] or Union[T1, T2, Union[T3, T4], etc...] |
| 1107 | + if any(item.get("anyOf") for item in property_val["anyOf"]): |
| 1108 | + raise ValueError( |
| 1109 | + f"The property with name {property_key} has a nested Optional or Union type, this is not allowed for dataclass JSON deserialization." |
| 1110 | + ) |
| 1111 | + attr_types = [] |
| 1112 | + for item in property_val["anyOf"]: |
| 1113 | + _, attr_type = _handle_json_schema_property(property_key, item) |
| 1114 | + attr_types.append(attr_type) |
| 1115 | + |
| 1116 | + # Gather all the types and return a Union[T1, T2, ...] |
| 1117 | + attr_union_type = reduce(lambda x, y: typing.Union[x, y], attr_types) |
| 1118 | + return (property_key, attr_union_type) # type: ignore |
| 1119 | + |
| 1120 | + # Handle enum |
| 1121 | + if property_val.get("enum"): |
| 1122 | + property_type = "enum" |
| 1123 | + else: |
| 1124 | + property_type = property_val["type"] |
| 1125 | + |
| 1126 | + # Handle list |
| 1127 | + if property_type == "array": |
| 1128 | + return (property_key, typing.List[_get_element_type(property_val["items"])]) # type: ignore |
| 1129 | + # Handle null types (i.e. None) |
| 1130 | + elif property_type == "null": |
| 1131 | + return (property_key, type(None)) # type: ignore |
| 1132 | + # Handle dataclass and dict |
| 1133 | + elif property_type == "object": |
| 1134 | + # NOTE: No need to handle optional dataclasses here (i.e. checking for property_val.get("anyOf")) |
| 1135 | + # those are handled in the top level of the function with recursion. |
| 1136 | + if property_val.get("additionalProperties"): |
| 1137 | + # For typing.Dict type |
| 1138 | + elem_type = _get_element_type(property_val["additionalProperties"]) |
| 1139 | + return (property_key, typing.Dict[str, elem_type]) # type: ignore |
| 1140 | + elif property_val.get("title"): |
| 1141 | + # For nested dataclass |
| 1142 | + sub_schema_name = property_val["title"] |
| 1143 | + return ( |
| 1144 | + property_key, |
| 1145 | + typing.cast(GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schema_name)), |
| 1146 | + ) |
| 1147 | + else: |
| 1148 | + # For untyped dict |
| 1149 | + return (property_key, dict) # type: ignore |
| 1150 | + elif property_type == "enum": |
| 1151 | + return (property_key, str) # type: ignore |
| 1152 | + # Handle None, int, float, bool or str |
| 1153 | + else: |
| 1154 | + return (property_key, _get_element_type(property_val)) # type: ignore |
| 1155 | + |
| 1156 | + |
| 1157 | +def generate_attribute_list_from_dataclass_json_mixin( |
| 1158 | + schema: dict, |
| 1159 | + schema_name: typing.Any, |
| 1160 | +): |
1093 | 1161 | attribute_list: typing.List[typing.Tuple[Any, Any]] = [] |
1094 | 1162 | for property_key, property_val in schema["properties"].items(): |
1095 | | - property_type = "" |
1096 | | - if property_val.get("anyOf"): |
1097 | | - property_type = property_val["anyOf"][0]["type"] |
1098 | | - elif property_val.get("enum"): |
1099 | | - property_type = "enum" |
1100 | | - else: |
1101 | | - property_type = property_val["type"] |
1102 | | - # Handle list |
1103 | | - if property_type == "array": |
1104 | | - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore |
1105 | | - # Handle dataclass and dict |
1106 | | - elif property_type == "object": |
1107 | | - if property_val.get("anyOf"): |
1108 | | - # For optional with dataclass |
1109 | | - sub_schemea = property_val["anyOf"][0] |
1110 | | - sub_schemea_name = sub_schemea["title"] |
1111 | | - attribute_list.append( |
1112 | | - ( |
1113 | | - property_key, |
1114 | | - typing.cast( |
1115 | | - GenericAlias, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name) |
1116 | | - ), |
1117 | | - ) |
1118 | | - ) |
1119 | | - elif property_val.get("additionalProperties"): |
1120 | | - # For typing.Dict type |
1121 | | - elem_type = _get_element_type(property_val["additionalProperties"]) |
1122 | | - attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore |
1123 | | - elif property_val.get("title"): |
1124 | | - # For nested dataclass |
1125 | | - sub_schemea_name = property_val["title"] |
1126 | | - attribute_list.append( |
1127 | | - ( |
1128 | | - property_key, |
1129 | | - typing.cast( |
1130 | | - GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name) |
1131 | | - ), |
1132 | | - ) |
1133 | | - ) |
1134 | | - else: |
1135 | | - # For untyped dict |
1136 | | - attribute_list.append((property_key, dict)) # type: ignore |
1137 | | - elif property_type == "enum": |
1138 | | - attribute_list.append([property_key, str]) # type: ignore |
1139 | | - # Handle int, float, bool or str |
1140 | | - else: |
1141 | | - attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore |
| 1163 | + attribute_list.append(_handle_json_schema_property(property_key, property_val)) |
1142 | 1164 | return attribute_list |
1143 | 1165 |
|
1144 | 1166 |
|
|
0 commit comments