Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions examples/basics/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import enum
from enum import StrEnum
from typing import Optional

from pydantic import BaseModel

import flyte

env = flyte.TaskEnvironment(name="enums")


# -- Enum definitions ----------------------------------------------------------


class Color(str, enum.Enum):
RED = "red-value"
GREEN = "green-value"
BLUE = "blue-value"


class Size(StrEnum):
SMALL = "sm-value"
MEDIUM = "md-value"
LARGE = "lg-value"
EXTRA_LARGE = "xl-value"


class Priority(str, enum.Enum):
LOW = "low-value"
MEDIUM = "medium-value"
HIGH = "high-value"
CRITICAL = "critical-value"


# -- Pydantic models with enums -----------------------------------------------


class ShirtOrder(BaseModel):
color: Color
size: Size
quantity: int


class Address(BaseModel):
street: str
city: str
zip_code: str


class Customer(BaseModel):
name: str
priority: Priority
address: Address


class FullOrder(BaseModel):
customer: Customer
items: list[ShirtOrder]
notes: Optional[str] = None


# -- Tasks ---------------------------------------------------------------------


@env.task
async def standalone_enum_echo(color: Color, size: Size) -> str:
"""Standalone enums as direct task inputs."""
return f"color={color.name}({color.value}), size={size.name}({size.value})"


@env.task
async def simple_pydantic_enum(order: ShirtOrder) -> str:
"""Enum fields inside a flat Pydantic model."""
return f"{order.quantity}x {order.color.name} shirt, size {order.size.name}"


@env.task
async def nested_pydantic_enum(order: FullOrder) -> str:
"""Enums inside nested Pydantic models."""
items_desc = ", ".join(f"{item.quantity}x {item.color.name}-{item.size.name}" for item in order.items)
return f"Order for {order.customer.name} (priority={order.customer.priority.name}): {items_desc}"


@env.task
async def main() -> list[str]:
results = []

# 1. Standalone enums
r = await standalone_enum_echo(color=Color.RED, size=Size.LARGE)
results.append(r)

# 2. Simple pydantic with enums
r = await simple_pydantic_enum(order=ShirtOrder(color=Color.BLUE, size=Size.MEDIUM, quantity=3))
results.append(r)

# 3. Nested pydantic with enums
order = FullOrder(
customer=Customer(
name="Alice",
priority=Priority.HIGH,
address=Address(street="123 Main St", city="Springfield", zip_code="62704"),
),
items=[
ShirtOrder(color=Color.RED, size=Size.SMALL, quantity=2),
ShirtOrder(color=Color.GREEN, size=Size.EXTRA_LARGE, quantity=1),
],
notes="Gift wrap please",
)
r = await nested_pydantic_enum(order=order)
results.append(r)

for line in results:
print(line)

return results


if __name__ == "__main__":
flyte.init_from_config()
# run = flyte.run(main)
run = flyte.run(standalone_enum_echo, color=Color.RED, size=Size.LARGE)
print(run.url)
137 changes: 125 additions & 12 deletions src/flyte/types/_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from mashumaro.jsonschema.schema import Instance
from mashumaro.mixins.json import DataClassJSONMixin
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Annotated, get_args, get_origin

import flyte.storage as storage
Expand Down Expand Up @@ -429,12 +430,115 @@ async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T
raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently")


def _unwrap_optional(tp: type) -> type:
"""Unwrap Optional[X] to X. Returns tp unchanged if not Optional."""
origin = get_origin(tp)
args = get_args(tp)
if origin is typing.Union:
non_none = [a for a in args if a is not type(None)]
if len(non_none) == 1:
return non_none[0]
return tp


def _convert_enum_field(value: typing.Any, field_type: type, *, to_names: bool) -> typing.Any:
"""Convert a value based on field type, handling enums, nested BaseModels, lists, and dicts.

When to_names=True (serialization): converts enum value strings to name strings.
When to_names=False (deserialization): converts enum name strings to enum instances.
"""
resolved = _unwrap_optional(field_type)

if value is None:
return None

# Direct enum field
if isinstance(resolved, type) and issubclass(resolved, enum.Enum):
if to_names:
# Serialization: value string → name string (e.g., "red" → "RED")
if isinstance(value, (str, int, float)):
try:
return resolved(value).name
except (ValueError, KeyError):
return value
else:
# Deserialization: name string → enum instance, with value fallback
if isinstance(value, str):
try:
return resolved[value] # Try name lookup first
except KeyError:
try:
return resolved(value) # Fall back to value lookup
except (ValueError, KeyError):
return value
return value

# Nested BaseModel
if isinstance(resolved, type) and issubclass(resolved, BaseModel):
if isinstance(value, dict):
return _walk_enum_fields(value, resolved, to_names=to_names)
return value

origin = get_origin(resolved)
args = get_args(resolved)

# list[X]
if origin is list and args and isinstance(value, list):
return [_convert_enum_field(item, args[0], to_names=to_names) for item in value]

# dict[K, V]
if origin is dict and len(args) == 2 and isinstance(value, dict):
key_type, val_type = args
return {
_convert_enum_field(k, key_type, to_names=to_names): _convert_enum_field(v, val_type, to_names=to_names)
for k, v in value.items()
}

return value


def _walk_enum_fields(data: dict, model_type: Type[BaseModel], *, to_names: bool) -> dict:
"""Walk a dict and convert enum fields guided by the model's type hints.

When to_names=True: converts enum value strings to name strings (for serialization).
When to_names=False: converts enum name strings to enum instances (for deserialization).
"""
try:
hints = typing.get_type_hints(model_type)
except Exception:
return data

result = {}
for key, value in data.items():
field_type = hints.get(key)
if field_type is None:
result[key] = value
continue
result[key] = _convert_enum_field(value, field_type, to_names=to_names)
return result


class CustomPydanticJsonSchemaGenerator(GenerateJsonSchema):
"""Custom JSON schema generator that uses enum member names instead of values.

This ensures consistency with EnumTransformer.get_literal_type(), which uses
enum names (e.name) for standalone enum types.
"""

def enum_schema(self, schema):
result = super().enum_schema(schema)
enum_cls = schema.get("cls")
if enum_cls and issubclass(enum_cls, enum.Enum) and "enum" in result:
result["enum"] = [e.name for e in enum_cls]
return result


class PydanticTransformer(TypeTransformer[BaseModel]):
def __init__(self):
super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)

def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
schema = t.model_json_schema()
schema = t.model_json_schema(schema_generator=CustomPydanticJsonSchemaGenerator)

meta_struct = struct_pb2.Struct()
meta_struct.update(
Expand Down Expand Up @@ -465,16 +569,15 @@ async def to_literal(

json_str = python_val.model_dump_json()
dict_obj = json.loads(json_str)
dict_obj = _walk_enum_fields(dict_obj, type(python_val), to_names=True)
msgpack_bytes = msgpack.dumps(dict_obj)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
if binary_idl_object.tag == MESSAGEPACK:
dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
json_str = json.dumps(dict_obj)
python_val = expected_python_type.model_validate_json(
json_data=json_str, strict=False, context={"deserialize": True}
)
dict_obj = _walk_enum_fields(dict_obj, expected_python_type, to_names=False)
python_val = expected_python_type.model_validate(dict_obj, strict=False, context={"deserialize": True})
return python_val
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
Expand All @@ -490,7 +593,9 @@ async def to_python_value(self, lv: Literal, expected_python_type: Type[BaseMode
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

json_str = _json_format.MessageToJson(lv.scalar.generic)
python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True})
dict_obj = json.loads(json_str)
dict_obj = _walk_enum_fields(dict_obj, expected_python_type, to_names=False)
python_val = expected_python_type.model_validate(dict_obj, strict=False, context={"deserialize": True})
return python_val


Expand All @@ -507,7 +612,7 @@ def get_schema(

try:
if issubclass(instance.type, BaseModel):
pydantic_schema = instance.type.model_json_schema()
pydantic_schema = instance.type.model_json_schema(schema_generator=CustomPydanticJsonSchemaGenerator)
return JSONSchema.from_dict(pydantic_schema)
except TypeError:
return None
Expand Down Expand Up @@ -866,13 +971,21 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

values = [v.value for v in t] # type: ignore
if not isinstance(values[0], str):
raise TypeTransformerFailedError("Only EnumTypes with value of string are supported")
return LiteralType(enum_type=types_pb2.EnumType(values=values))
raise TypeTransformerFailedError("Only EnumTypes with name of value are supported")
names = [v.name for v in t] # type: ignore
return LiteralType(enum_type=types_pb2.EnumType(values=names))

async def to_literal(self, python_val: enum.Enum, python_type: Type[T], expected: LiteralType) -> Literal:
if isinstance(python_val, str):
# this is the case when python Literals are used as enums
if python_val not in expected.enum_type.values:
if python_val.__getattribute__("name"):
if python_val.__getattribute__("name") not in expected.enum_type.values:
raise TypeTransformerFailedError(
f"Value {python_val.__getattribute__('name')} is not valid value, expected -"
f" {expected.enum_type.values}"
)
return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.__getattribute__("name")))) # type: ignore
elif python_val not in expected.enum_type.values:
raise TypeTransformerFailedError(
f"Value {python_val} is not valid value, expected - {expected.enum_type.values}"
)
Expand All @@ -882,7 +995,7 @@ async def to_literal(self, python_val: enum.Enum, python_type: Type[T], expected
if type(python_val.value) is not str:
raise TypeTransformerFailedError("Only string-valued enums are supported")

return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore
return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.name))) # type: ignore

async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
if lv.HasField("scalar") and lv.scalar.HasField("binary"):
Expand All @@ -893,7 +1006,7 @@ async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T
# This is the case when python Literal types are used as enums. The class name is always LiteralEnum an
# hardcoded in flyte.models
return lv.scalar.primitive.string_value
return expected_python_type(lv.scalar.primitive.string_value) # type: ignore
return expected_python_type[lv.scalar.primitive.string_value] # type: ignore

def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
if literal_type.HasField("enum_type"):
Expand Down
Loading
Loading