Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions examples/basics/types/pydantic_nested_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import asyncio
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel

import flyte

# Task environment for this example; its `task` decorator is applied to every task function in this file.
env = flyte.TaskEnvironment(name="inputs_pydantic_nested_types")


class Status(str, Enum):
    """Two-state flag used as the ``status`` field of ``ModelWithEnum``.

    The ``str`` mixin makes every member a string equal to its value.
    """

    ACTIVE = "active"
    INACTIVE = "inactive"


class Inner(BaseModel):
    """Smallest model in this example: a string ``name`` paired with an integer ``value``."""

    name: str
    value: int


class ModelWithEnum(BaseModel):
    """Model pairing a string ``label`` with a ``Status`` enum member."""

    label: str
    status: Status


@env.task
async def nested_lists(matrix: List[List[int]]) -> str:
    """Describe a list of integer lists as ``rows x cols`` followed by its contents.

    The column count is taken from the first row. An empty ``matrix`` is
    reported as ``0x0`` instead of raising ``IndexError`` on ``matrix[0]``.
    """
    rows = len(matrix)
    cols = len(matrix[0]) if matrix else 0
    return f"Matrix {rows}x{cols}: {matrix}"


@env.task
async def list_of_dicts(records: List[Dict[str, int]]) -> str:
    """Return a summary of ``records``: how many dicts it holds and their contents."""
    count = len(records)
    return f"Records ({count} entries): {records}"


@env.task
async def dict_of_dicts(nested_map: Dict[str, Dict[str, int]]) -> str:
    """Return the outer keys of ``nested_map`` followed by the whole mapping."""
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    outer_keys = list(nested_map)
    return f"Nested map keys: {outer_keys}, values: {nested_map}"


@env.task
async def nested_models(items: List[List[Inner]]) -> str:
    """Return the ``name`` of every ``Inner`` in ``items``, keeping the row structure."""
    names = [[model.name for model in row] for row in items]
    return f"Nested model names: {names}"


@env.task
async def dict_of_models(models: Dict[str, Inner]) -> str:
    """Return the keys of a mapping from string to ``Inner``."""
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    keys = list(models)
    return f"Dict of models: {keys}"


@env.task
async def enum_in_models(jobs: List[ModelWithEnum]) -> str:
    """Return each job as a ``(label, status value)`` pair."""
    pairs = [(job.label, job.status.value) for job in jobs]
    return f"Jobs: {pairs}"


@env.task
async def optional_model(inner: Optional[Inner] = None) -> str:
    """Return a string showing ``inner``, which may be omitted (defaults to ``None``)."""
    return f"Optional inner: {inner}"


class ComplexNestedModel(BaseModel):
    """Single model combining the nested container shapes used in this example.

    Its fields mix lists, dicts, ``Inner`` models and ``ModelWithEnum`` models
    at several nesting depths; only ``optional_inner`` has a default.
    """

    nested_list: List[List[Inner]]
    dict_of_model_lists: Dict[str, List[Inner]]
    list_of_model_dicts: List[Dict[str, Inner]]
    enum_model_map: Dict[str, ModelWithEnum]
    list_of_dicts: List[Dict[str, int]]
    optional_inner: Optional[Inner] = None


@env.task
async def complex_nesting(data: ComplexNestedModel) -> str:
    """Summarise every field of ``data``, one line per field, joined with newlines."""
    enum_pairs = [(key, model.status.value) for key, model in data.enum_model_map.items()]
    lines = [
        f"Nested list rows: {len(data.nested_list)}",
        f"Dict of model lists keys: {list(data.dict_of_model_lists)}",
        f"List of model dicts count: {len(data.list_of_model_dicts)}",
        f"Enum model map: {enum_pairs}",
        f"List of dicts: {data.list_of_dicts}",
        f"Optional inner: {data.optional_inner}",
    ]
    return "\n".join(lines)


@env.task
async def main() -> str:
    """Call every task in this example concurrently and return their results.

    The individual results are joined with newlines, in the order in which
    the calls are listed below (``asyncio.gather`` preserves argument order).
    """
    results = await asyncio.gather(
        nested_lists(matrix=[[1, 2], [3, 4]]),
        list_of_dicts(records=[{"a": 1, "b": 2}, {"c": 3}]),
        dict_of_dicts(nested_map={"outer": {"inner_key": 10}}),
        nested_models(items=[[Inner(name="a", value=1)], [Inner(name="b", value=2)]]),
        dict_of_models(models={"x": Inner(name="c", value=3)}),
        enum_in_models(
            jobs=[
                ModelWithEnum(label="job1", status=Status.ACTIVE),
                ModelWithEnum(label="job2", status=Status.INACTIVE),
            ]
        ),
        # optional_model is called twice: once with a value, once relying on the default.
        optional_model(inner=Inner(name="d", value=4)),
        optional_model(),
        complex_nesting(
            data=ComplexNestedModel(
                nested_list=[[Inner(name="a", value=1)], [Inner(name="b", value=2)]],
                dict_of_model_lists={"group1": [Inner(name="c", value=3), Inner(name="d", value=4)]},
                list_of_model_dicts=[{"x": Inner(name="e", value=5)}, {"y": Inner(name="f", value=6)}],
                enum_model_map={
                    "j1": ModelWithEnum(label="job1", status=Status.ACTIVE),
                    "j2": ModelWithEnum(label="job2", status=Status.INACTIVE),
                },
                list_of_dicts=[{"k1": 10, "k2": 20}],
                optional_inner=Inner(name="g", value=7),
            )
        ),
    )

    # str() mirrors the f-string formatting the results previously went through.
    return "\n".join(str(result) for result in results)


if __name__ == "__main__":
    # Initialise flyte from its config before launching anything.
    flyte.init_from_config()

    # Launch `main`, print the run's name and URL, then call wait() on the run.
    run = flyte.run(main)
    print(run.name)
    print(run.url)
    run.wait()
75 changes: 62 additions & 13 deletions src/flyte/types/_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
property_type = property_val["type"]
# Handle list
if property_type == "array":
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"], schema)])) # type: ignore
# Handle dataclass and dict
elif property_type == "object":
if property_val.get("anyOf"):
Expand Down Expand Up @@ -1004,7 +1004,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
nested_types[property_key] = nested_class
elif property_val.get("additionalProperties"):
# For typing.Dict type
elem_type = _get_element_type(property_val["additionalProperties"])
elem_type = _get_element_type(property_val["additionalProperties"], schema)
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
elif property_val.get("title"):
# For nested dataclass
Expand Down Expand Up @@ -1047,7 +1047,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
attribute_list.append([property_key, str]) # type: ignore
# Handle int, float, bool or str
else:
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
attribute_list.append([property_key, _get_element_type(property_val, schema)]) # type: ignore
return attribute_list, nested_types


Expand Down Expand Up @@ -2076,7 +2076,14 @@ def __init__(self, *args, **kwargs): # type: ignore[misc]
return cls


def _get_element_type(element_property: typing.Union[typing.Dict[str, str], bool]) -> Type:
# Values in a JSON schema are not always strings: they can be dicts (e.g. `items`, `additionalProperties`), lists
# (e.g. `anyOf`) or bools, so `element_property` is typed as Dict[str, Any] rather than Dict[str, str].
# `schema` is the enclosing schema, which `_get_element_type` needs in order to look up `$defs` when resolving
# `$ref` paths. It defaults to None, which keeps existing callers working.
def _get_element_type(
element_property: typing.Union[typing.Dict[str, typing.Any], bool],
schema: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> Type:
from flyte.io._dir import Dir
from flyte.io._file import File

Expand All @@ -2091,16 +2098,48 @@ def _get_element_type(element_property: typing.Union[typing.Dict[str, str], bool
return File
elif Dir.schema_match(element_property):
return Dir
element_type = (
[e_property["type"] for e_property in element_property["anyOf"]] # type: ignore
if element_property.get("anyOf")
else element_property["type"]
)
element_format = element_property["format"] if "format" in element_property else None

if isinstance(element_type, list):
# Element type of Optional[int] is [integer, None]
return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore
# Handle $ref for nested models and enums

# Ensure that the element is actually a $ref and we have the entire schema to look up
if element_property.get("$ref") and schema is not None:
ref_name = element_property["$ref"].split("/")[-1]
defs = schema.get("$defs", schema.get("definitions", {}))
# Look up for ref_name in the defs defined in the schema
if ref_name in defs:
# Don't mutate the original schema
ref_schema = defs[ref_name].copy()
# Guard the nested enum elements inside containers
if ref_schema.get("enum"):
return str
# If the resolved schema has no $defs of its own, the parent's defs need to be propagated into it
if "$defs" not in ref_schema and defs:
ref_schema["$defs"] = defs
# build a dataclass from the resolved schema
return convert_mashumaro_json_schema_to_python_class(ref_schema, ref_name)
# default to str on failure. Shouldn't happen with valid pydantic schemas
return str

# Handle anyOf (e.g. Optional[int], Optional[Inner])
# Early-return block replacing the previous list comprehension, which would fail when an anyOf variant was a $ref
# (meaning it has no "type" key).
if element_property.get("anyOf"):
# Separate the non-null variants. Note that for a $ref variant `v.get("type")` is None, not "null", so it is kept;
# only a {"type": "null"} variant is filtered out.
variants = element_property["anyOf"]
non_null = [v for v in variants if v.get("type") != "null"]
# Detect if this is an Optional pattern here
has_null = len(non_null) < len(variants)
# This recurses on the first non-null variant which would handle the $ref, nested_arrays, nested_objects...
# anything. Wrap it in Optional if has_null.
if non_null:
inner_type = _get_element_type(non_null[0], schema)
return typing.Optional[inner_type] if has_null else inner_type # type: ignore
# Return NoneType if every variant is null
return type(None)

element_type = element_property.get("type", "string")
element_format = element_property.get("format")

if element_type == "string":
return str
Expand All @@ -2113,6 +2152,16 @@ def _get_element_type(element_property: typing.Union[typing.Dict[str, str], bool
return int
else:
return float
# Recursively discover the types when an array or object element type is discovered
elif element_type == "array":
return typing.List[_get_element_type(element_property.get("items", {}), schema)] # type: ignore
elif element_type == "object":
if element_property.get("additionalProperties"):
return typing.Dict[str, _get_element_type(element_property["additionalProperties"], schema)] # type: ignore
return dict
# Corner case - practically useless but List[None] is a legal Python type
elif element_type == "null":
return type(None)
return str


Expand Down
Loading
Loading