From 2630515dea05eafb422767c1e40c00ac4105f176 Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Thu, 24 Jul 2025 14:50:37 -0700 Subject: [PATCH 1/9] [Core feature] Add support Literal Transformer Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 86 ++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 58ba0b8556..32e065f396 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -19,6 +19,7 @@ from functools import lru_cache, reduce from types import GenericAlias from typing import Any, Dict, List, NamedTuple, Optional, Type, cast +from typing import Literal as TypingLiteral import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json @@ -1087,6 +1088,85 @@ def assert_type(self, t: Type[enum.Enum], v: T): raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}") +class LiteralTypeTransformer(TypeTransformer[TypingLiteral]): + def __init__(self): + super().__init__("LiteralTypeTransformer", TypingLiteral) + + def get_literal_type(self, t: Type) -> LiteralType: + args = get_args(t) + if not args: + raise TypeTransformerFailedError("Literal must have at least one value") + + base_type = type(args[0]) + if not all(type(a) == base_type for a in args): + raise TypeTransformerFailedError("All values must be of the same type") + + if base_type == str: + return LiteralType(simple=SimpleType.STRING) + elif base_type == int: + return LiteralType(simple=SimpleType.INTEGER) + elif base_type == float: + return LiteralType(simple=SimpleType.FLOAT) + elif base_type == bool: + return LiteralType(simple=SimpleType.BOOLEAN) + elif base_type == datetime.datetime: + return LiteralType(simple=SimpleType.DATETIME) + elif base_type == datetime.timedelta: + return LiteralType(simple=SimpleType.DURATION) + else: + raise TypeTransformerFailedError(f"Unsupported Literal base type: {base_type}") + + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + if expected.simple == SimpleType.STRING: + return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) + elif expected.simple == SimpleType.INTEGER: + return Literal(scalar=Scalar(primitive=Primitive(integer=python_val))) + elif expected.simple == SimpleType.FLOAT: + return Literal(scalar=Scalar(primitive=Primitive(float_value=python_val))) + elif expected.simple == SimpleType.BOOLEAN: + return Literal(scalar=Scalar(primitive=Primitive(boolean=python_val))) + elif expected.simple == SimpleType.DATETIME: + return Literal(scalar=Scalar(primitive=Primitive(datetime=python_val))) + elif expected.simple == SimpleType.DURATION: + return Literal(scalar=Scalar(primitive=Primitive(duration=python_val))) + else: + raise TypeError(f"Unsupported LiteralType for LiteralTypeTransformer: {expected.simple}") + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + if lv.scalar.primitive.string_value is not None: + return lv.scalar.primitive.string_value + elif lv.scalar.primitive.integer is not None: + return lv.scalar.primitive.integer + elif lv.scalar.primitive.float_value is not None: + return lv.scalar.primitive.float_value + elif lv.scalar.primitive.boolean is not None: + return lv.scalar.primitive.boolean + elif lv.scalar.primitive.datetime is not None: + return lv.scalar.primitive.datetime + elif lv.scalar.primitive.duration is not None: + return lv.scalar.primitive.duration + else: + raise TypeTransformerFailedError("Unsupported Literal value") + + def guess_python_type(self, literal_type: LiteralType): + if literal_type.simple == SimpleType.STRING: + return str + elif literal_type.simple == SimpleType.INTEGER: + return int + elif literal_type.simple == SimpleType.FLOAT: + return float + elif literal_type.simple == SimpleType.BOOLEAN: + return bool + elif literal_type.simple == SimpleType.DATETIME: + return datetime.datetime + elif literal_type.simple == SimpleType.DURATION: + return datetime.timedelta + else: + raise TypeTransformerFailedError(f"LiteralTypeTransformer cannot reverse {literal_type}") + + def _handle_json_schema_property( property_key: str, property_val: dict, @@ -1173,6 +1253,7 @@ class TypeEngine(typing.Generic[T]): _RESTRICTED_TYPES: typing.List[type] = [] _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore _ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore + _LITERAL_TYPE_TRANSFORMER: TypeTransformer = LiteralTypeTransformer() lazy_import_lock = threading.Lock() @classmethod @@ -1222,6 +1303,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]: # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used. return cls._ENUM_TRANSFORMER + if get_origin(python_type) == TypingLiteral: + return cls._LITERAL_TYPE_TRANSFORMER + if hasattr(python_type, "__origin__"): # If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON] # or List[int] has been specifically registered; we should check for the entire type. @@ -1253,6 +1337,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: """ Implements a recursive search for the transformer. """ + logger.warning(f"get_transformer: {python_type}") v = cls._get_transformer(python_type) if v is not None: return v @@ -2606,6 +2691,7 @@ def _register_default_type_transformers(): TypeEngine.register(BinaryIOTransformer()) TypeEngine.register(EnumTransformer()) TypeEngine.register(ProtobufTransformer()) + TypeEngine.register(LiteralTypeTransformer()) # inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system # doesn't support these currently. From 6f2cdc70c067173cdd8b15797e5a7e973c7e09ce Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Thu, 24 Jul 2025 14:50:37 -0700 Subject: [PATCH 2/9] [Core feature] Add support Literal Transformer Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 85 ++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 58ba0b8556..8b121d88ca 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -19,6 +19,7 @@ from functools import lru_cache, reduce from types import GenericAlias from typing import Any, Dict, List, NamedTuple, Optional, Type, cast +from typing import Literal as TypingLiteral import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json @@ -1087,6 +1088,85 @@ def assert_type(self, t: Type[enum.Enum], v: T): raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}") +class LiteralTypeTransformer(TypeTransformer[TypingLiteral]): + def __init__(self): + super().__init__("LiteralTypeTransformer", TypingLiteral) + + def get_literal_type(self, t: Type) -> LiteralType: + args = get_args(t) + if not args: + raise TypeTransformerFailedError("Literal must have at least one value") + + base_type = type(args[0]) + if not all(type(a) == base_type for a in args): + raise TypeTransformerFailedError("All values must be of the same type") + + if base_type == str: + return LiteralType(simple=SimpleType.STRING) + elif base_type == int: + return LiteralType(simple=SimpleType.INTEGER) + elif base_type == float: + return LiteralType(simple=SimpleType.FLOAT) + elif base_type == bool: + return LiteralType(simple=SimpleType.BOOLEAN) + elif base_type == datetime.datetime: + return LiteralType(simple=SimpleType.DATETIME) + elif base_type == datetime.timedelta: + return LiteralType(simple=SimpleType.DURATION) + else: + raise TypeTransformerFailedError(f"Unsupported Literal base type: {base_type}") + + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + if expected.simple == SimpleType.STRING: + return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) + elif expected.simple == SimpleType.INTEGER: + return Literal(scalar=Scalar(primitive=Primitive(integer=python_val))) + elif expected.simple == SimpleType.FLOAT: + return Literal(scalar=Scalar(primitive=Primitive(float_value=python_val))) + elif expected.simple == SimpleType.BOOLEAN: + return Literal(scalar=Scalar(primitive=Primitive(boolean=python_val))) + elif expected.simple == SimpleType.DATETIME: + return Literal(scalar=Scalar(primitive=Primitive(datetime=python_val))) + elif expected.simple == SimpleType.DURATION: + return Literal(scalar=Scalar(primitive=Primitive(duration=python_val))) + else: + raise TypeError(f"Unsupported LiteralType for LiteralTypeTransformer: {expected.simple}") + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + if lv.scalar.primitive.string_value is not None: + return lv.scalar.primitive.string_value + elif lv.scalar.primitive.integer is not None: + return lv.scalar.primitive.integer + elif lv.scalar.primitive.float_value is not None: + return lv.scalar.primitive.float_value + elif lv.scalar.primitive.boolean is not None: + return lv.scalar.primitive.boolean + elif lv.scalar.primitive.datetime is not None: + return lv.scalar.primitive.datetime + elif lv.scalar.primitive.duration is not None: + return lv.scalar.primitive.duration + else: + raise TypeTransformerFailedError("Unsupported Literal value") + + def guess_python_type(self, literal_type: LiteralType): + if literal_type.simple == SimpleType.STRING: + return str + elif literal_type.simple == SimpleType.INTEGER: + return int + elif literal_type.simple == SimpleType.FLOAT: + return float + elif literal_type.simple == SimpleType.BOOLEAN: + return bool + elif literal_type.simple == SimpleType.DATETIME: + return datetime.datetime + elif literal_type.simple == SimpleType.DURATION: + return datetime.timedelta + else: + raise TypeTransformerFailedError(f"LiteralTypeTransformer cannot reverse {literal_type}") + + def _handle_json_schema_property( property_key: str, property_val: dict, @@ -1173,6 +1253,7 @@ class TypeEngine(typing.Generic[T]): _RESTRICTED_TYPES: typing.List[type] = [] _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore _ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore + _LITERAL_TYPE_TRANSFORMER: TypeTransformer = LiteralTypeTransformer() lazy_import_lock = threading.Lock() @classmethod @@ -1222,6 +1303,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]: # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used. return cls._ENUM_TRANSFORMER + if get_origin(python_type) == TypingLiteral: + return cls._LITERAL_TYPE_TRANSFORMER + if hasattr(python_type, "__origin__"): # If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON] # or List[int] has been specifically registered; we should check for the entire type. @@ -2606,6 +2690,7 @@ def _register_default_type_transformers(): TypeEngine.register(BinaryIOTransformer()) TypeEngine.register(EnumTransformer()) TypeEngine.register(ProtobufTransformer()) + TypeEngine.register(LiteralTypeTransformer()) # inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system # doesn't support these currently. From 6f10cdd14a3e6d833d6a49258e1e380fae723a35 Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Mon, 4 Aug 2025 21:20:12 -0700 Subject: [PATCH 3/9] Chage to call corresponding transformer Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 51 +++++++++++++++++------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 32e065f396..4bf909adff 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1102,67 +1102,65 @@ def get_literal_type(self, t: Type) -> LiteralType: raise TypeTransformerFailedError("All values must be of the same type") if base_type == str: - return LiteralType(simple=SimpleType.STRING) + return StrTransformer.get_literal_type(args[0]) elif base_type == int: - return LiteralType(simple=SimpleType.INTEGER) + return IntTransformer.get_literal_type(args[0]) elif base_type == float: - return LiteralType(simple=SimpleType.FLOAT) + return FloatTransformer.get_literal_type(args[0]) elif base_type == bool: - return LiteralType(simple=SimpleType.BOOLEAN) + return BoolTransformer.get_literal_type(args[0]) elif base_type == datetime.datetime: - return LiteralType(simple=SimpleType.DATETIME) + return DatetimeTransformer.get_literal_type(args[0]) elif base_type == datetime.timedelta: - return LiteralType(simple=SimpleType.DURATION) + return TimedeltaTransformer.get_literal_type(args[0]) else: raise TypeTransformerFailedError(f"Unsupported Literal base type: {base_type}") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if expected.simple == SimpleType.STRING: - return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) + return StrTransformer.to_literal(ctx, python_val, python_type, expected) elif expected.simple == SimpleType.INTEGER: - return Literal(scalar=Scalar(primitive=Primitive(integer=python_val))) + return IntTransformer.to_literal(ctx, python_val, python_type, expected) elif expected.simple == SimpleType.FLOAT: - return Literal(scalar=Scalar(primitive=Primitive(float_value=python_val))) + return FloatTransformer.to_literal(ctx, python_val, python_type, expected) elif expected.simple == SimpleType.BOOLEAN: - return Literal(scalar=Scalar(primitive=Primitive(boolean=python_val))) + return BoolTransformer.to_literal(ctx, python_val, python_type, expected) elif expected.simple == SimpleType.DATETIME: - return Literal(scalar=Scalar(primitive=Primitive(datetime=python_val))) + return DatetimeTransformer.to_literal(ctx, python_val, python_type, expected) elif expected.simple == SimpleType.DURATION: - return Literal(scalar=Scalar(primitive=Primitive(duration=python_val))) + return TimedeltaTransformer.to_literal(ctx, python_val, python_type, expected) else: raise TypeError(f"Unsupported LiteralType for LiteralTypeTransformer: {expected.simple}") def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore if lv.scalar.primitive.string_value is not None: - return lv.scalar.primitive.string_value + return StrTransformer.to_python_value(ctx, lv, str) elif lv.scalar.primitive.integer is not None: - return lv.scalar.primitive.integer + return IntTransformer.to_python_value(ctx, lv, int) elif lv.scalar.primitive.float_value is not None: - return lv.scalar.primitive.float_value + return FloatTransformer.to_python_value(ctx, lv, float) elif lv.scalar.primitive.boolean is not None: - return lv.scalar.primitive.boolean + return BoolTransformer.to_python_value(ctx, lv, bool) elif lv.scalar.primitive.datetime is not None: - return lv.scalar.primitive.datetime + return DatetimeTransformer.to_python_value(ctx, lv, datetime.datetime) elif lv.scalar.primitive.duration is not None: - return lv.scalar.primitive.duration + return TimedeltaTransformer.to_python_value(ctx, lv, datetime.timedelta) else: raise TypeTransformerFailedError("Unsupported Literal value") def guess_python_type(self, literal_type: LiteralType): if literal_type.simple == SimpleType.STRING: - return str + return StrTransformer.guess_python_type(literal_type) elif literal_type.simple == SimpleType.INTEGER: - return int + return IntTransformer.guess_python_type(literal_type) elif literal_type.simple == SimpleType.FLOAT: - return float + return FloatTransformer.guess_python_type(literal_type) elif literal_type.simple == SimpleType.BOOLEAN: - return bool + return BoolTransformer.guess_python_type(literal_type) elif literal_type.simple == SimpleType.DATETIME: - return datetime.datetime + return DatetimeTransformer.guess_python_type(literal_type) elif literal_type.simple == SimpleType.DURATION: - return datetime.timedelta + return TimedeltaTransformer.guess_python_type(literal_type) else: raise TypeTransformerFailedError(f"LiteralTypeTransformer cannot reverse {literal_type}") @@ -1337,7 +1335,6 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: """ Implements a recursive search for the transformer. """ - logger.warning(f"get_transformer: {python_type}") v = cls._get_transformer(python_type) if v is not None: return v From 75020d5d68e4a93a609375fb1e1542880205b567 Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Wed, 6 Aug 2025 22:17:03 -0700 Subject: [PATCH 4/9] Fix CI error by changing input and output type Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 4bf909adff..e45d02ced1 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -19,7 +19,6 @@ from functools import lru_cache, reduce from types import GenericAlias from typing import Any, Dict, List, NamedTuple, Optional, Type, cast -from typing import Literal as TypingLiteral import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json @@ -1088,9 +1087,9 @@ def assert_type(self, t: Type[enum.Enum], v: T): raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}") -class LiteralTypeTransformer(TypeTransformer[TypingLiteral]): +class LiteralTypeTransformer(TypeTransformer[object]): def __init__(self): - super().__init__("LiteralTypeTransformer", TypingLiteral) + super().__init__("LiteralTypeTransformer", object) def get_literal_type(self, t: Type) -> LiteralType: args = get_args(t) @@ -1116,23 +1115,23 @@ def get_literal_type(self, t: Type) -> LiteralType: else: raise TypeTransformerFailedError(f"Unsupported Literal base type: {base_type}") - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: - if expected.simple == SimpleType.STRING: - return StrTransformer.to_literal(ctx, python_val, python_type, expected) - elif expected.simple == SimpleType.INTEGER: - return IntTransformer.to_literal(ctx, python_val, python_type, expected) - elif expected.simple == SimpleType.FLOAT: - return FloatTransformer.to_literal(ctx, python_val, python_type, expected) - elif expected.simple == SimpleType.BOOLEAN: - return BoolTransformer.to_literal(ctx, python_val, python_type, expected) - elif expected.simple == SimpleType.DATETIME: - return DatetimeTransformer.to_literal(ctx, python_val, python_type, expected) - elif expected.simple == SimpleType.DURATION: - return TimedeltaTransformer.to_literal(ctx, python_val, python_type, expected) + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal: + if expected.simple == SimpleType.STRING and isinstance(python_val, str): + return StrTransformer.to_literal(ctx, python_val, str, expected) + elif expected.simple == SimpleType.INTEGER and isinstance(python_val, int): + return IntTransformer.to_literal(ctx, python_val, int, expected) + elif expected.simple == SimpleType.FLOAT and isinstance(python_val, float): + return FloatTransformer.to_literal(ctx, python_val, float, expected) + elif expected.simple == SimpleType.BOOLEAN and isinstance(python_val, bool): + return BoolTransformer.to_literal(ctx, python_val, bool, expected) + elif expected.simple == SimpleType.DATETIME and isinstance(python_val, datetime.datetime): + return DatetimeTransformer.to_literal(ctx, python_val, datetime.datetime, expected) + elif expected.simple == SimpleType.DURATION and isinstance(python_val, datetime.timedelta): + return TimedeltaTransformer.to_literal(ctx, python_val, datetime.timedelta, expected) else: raise TypeError(f"Unsupported LiteralType for LiteralTypeTransformer: {expected.simple}") - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object: if lv.scalar.primitive.string_value is not None: return StrTransformer.to_python_value(ctx, lv, str) elif lv.scalar.primitive.integer is not None: @@ -1301,7 +1300,7 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]: # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used. return cls._ENUM_TRANSFORMER - if get_origin(python_type) == TypingLiteral: + if get_origin(python_type) == typing.Literal: return cls._LITERAL_TYPE_TRANSFORMER if hasattr(python_type, "__origin__"): @@ -2688,7 +2687,6 @@ def _register_default_type_transformers(): TypeEngine.register(BinaryIOTransformer()) TypeEngine.register(EnumTransformer()) TypeEngine.register(ProtobufTransformer()) - TypeEngine.register(LiteralTypeTransformer()) # inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system # doesn't support these currently. From 3009641a10f33a3fe647a5785f2bc884f5c82c7c Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Sat, 23 Aug 2025 01:51:08 +0800 Subject: [PATCH 5/9] Use get_transformer to replace if-else condition and add assert type Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 91 ++++++++++++++---------------------- 1 file changed, 35 insertions(+), 56 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e45d02ced1..c624671edc 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1100,68 +1100,47 @@ def get_literal_type(self, t: Type) -> LiteralType: if not all(type(a) == base_type for a in args): raise TypeTransformerFailedError("All values must be of the same type") - if base_type == str: - return StrTransformer.get_literal_type(args[0]) - elif base_type == int: - return IntTransformer.get_literal_type(args[0]) - elif base_type == float: - return FloatTransformer.get_literal_type(args[0]) - elif base_type == bool: - return BoolTransformer.get_literal_type(args[0]) - elif base_type == datetime.datetime: - return DatetimeTransformer.get_literal_type(args[0]) - elif base_type == datetime.timedelta: - return TimedeltaTransformer.get_literal_type(args[0]) - else: - raise TypeTransformerFailedError(f"Unsupported Literal base type: {base_type}") + base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) + return base_transformer.get_literal_type(base_type) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal: - if expected.simple == SimpleType.STRING and isinstance(python_val, str): - return StrTransformer.to_literal(ctx, python_val, str, expected) - elif expected.simple == SimpleType.INTEGER and isinstance(python_val, int): - return IntTransformer.to_literal(ctx, python_val, int, expected) - elif expected.simple == SimpleType.FLOAT and isinstance(python_val, float): - return FloatTransformer.to_literal(ctx, python_val, float, expected) - elif expected.simple == SimpleType.BOOLEAN and isinstance(python_val, bool): - return BoolTransformer.to_literal(ctx, python_val, bool, expected) - elif expected.simple == SimpleType.DATETIME and isinstance(python_val, datetime.datetime): - return DatetimeTransformer.to_literal(ctx, python_val, datetime.datetime, expected) - elif expected.simple == SimpleType.DURATION and isinstance(python_val, datetime.timedelta): - return TimedeltaTransformer.to_literal(ctx, python_val, datetime.timedelta, expected) - else: - raise TypeError(f"Unsupported LiteralType for LiteralTypeTransformer: {expected.simple}") + args = get_args(python_type) + if not args: + raise TypeTransformerFailedError("Literal must have at least one value") + + base_type = type(args[0]) + if not all(type(a) == base_type for a in args): + raise TypeTransformerFailedError("All values must be of the same type") + + base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) + return base_transformer.to_literal(ctx, python_val, python_type, expected) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object: - if lv.scalar.primitive.string_value is not None: - return StrTransformer.to_python_value(ctx, lv, str) - elif lv.scalar.primitive.integer is not None: - return IntTransformer.to_python_value(ctx, lv, int) - elif lv.scalar.primitive.float_value is not None: - return FloatTransformer.to_python_value(ctx, lv, float) - elif lv.scalar.primitive.boolean is not None: - return BoolTransformer.to_python_value(ctx, lv, bool) - elif lv.scalar.primitive.datetime is not None: - return DatetimeTransformer.to_python_value(ctx, lv, datetime.datetime) - elif lv.scalar.primitive.duration is not None: - return TimedeltaTransformer.to_python_value(ctx, lv, datetime.timedelta) - else: - raise TypeTransformerFailedError("Unsupported Literal value") + args = get_args(expected_python_type) + if not args: + raise TypeTransformerFailedError("Literal must have at least one value") + + base_type = type(args[0]) + if not all(type(a) == base_type for a in args): + raise TypeTransformerFailedError("All values must be of the same type") + + base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) + return base_transformer.to_python_value(ctx, lv, base_type) def guess_python_type(self, literal_type: LiteralType): - if literal_type.simple == SimpleType.STRING: - return StrTransformer.guess_python_type(literal_type) - elif literal_type.simple == SimpleType.INTEGER: - return IntTransformer.guess_python_type(literal_type) - elif literal_type.simple == SimpleType.FLOAT: - return FloatTransformer.guess_python_type(literal_type) - elif literal_type.simple == SimpleType.BOOLEAN: - return BoolTransformer.guess_python_type(literal_type) - elif literal_type.simple == SimpleType.DATETIME: - return DatetimeTransformer.guess_python_type(literal_type) - elif literal_type.simple == SimpleType.DURATION: - return TimedeltaTransformer.guess_python_type(literal_type) - else: - raise TypeTransformerFailedError(f"LiteralTypeTransformer cannot reverse {literal_type}") + return TypeEngine.guess_python_type(literal_type) + + def assert_type(self, python_type: Type, python_val: T): + args = get_args(python_type) + if not args: + raise TypeTransformerFailedError("Literal must have at least one value") + + base_type = type(args[0]) + if not all(type(a) == base_type for a in args): + raise TypeTransformerFailedError("All values must be of the same type") + + base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) + return base_transformer.assert_type(base_type, python_val) def _handle_json_schema_property( From 72704fd07cbb40968c58079f080a3ddac6688a31 Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Wed, 27 Aug 2025 23:29:28 +0800 Subject: [PATCH 6/9] Add get_base_type helper function, fix guess_python_type, and add unit tests Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 62 +++++++++++--------- tests/flytekit/unit/core/test_type_engine.py | 39 ++++++++++++ 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c624671edc..722f831049 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1091,7 +1091,8 @@ class LiteralTypeTransformer(TypeTransformer[object]): def __init__(self): super().__init__("LiteralTypeTransformer", object) - def get_literal_type(self, t: Type) -> LiteralType: + @classmethod + def get_base_type(cls, t: Type) -> Type: args = get_args(t) if not args: raise TypeTransformerFailedError("Literal must have at least one value") @@ -1100,45 +1101,48 @@ def get_literal_type(self, t: Type) -> LiteralType: if not all(type(a) == base_type for a in args): raise TypeTransformerFailedError("All values must be of the same type") - base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) - return base_transformer.get_literal_type(base_type) - - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal: - args = get_args(python_type) - if not args: - raise TypeTransformerFailedError("Literal must have at least one value") + return base_type - base_type = type(args[0]) - if not all(type(a) == base_type for a in args): - raise TypeTransformerFailedError("All values must be of the same type") + def get_literal_type(self, t: Type) -> LiteralType: + base_type = self.get_base_type(t) + vals = list(get_args(t)) + ann = TypeAnnotationModel(annotations={"literal_values": vals}) + if base_type is str: + simple = SimpleType.STRING + elif base_type is int: + simple = SimpleType.INTEGER + elif base_type is float: + simple = SimpleType.FLOAT + elif base_type is bool: + simple = SimpleType.BOOLEAN + elif base_type is datetime.datetime: + simple = SimpleType.DATETIME + elif base_type is datetime.timedelta: + simple = SimpleType.DURATION + else: + raise TypeTransformerFailedError(f"Unsupported type: {base_type}") + return LiteralType(simple=simple, annotation=ann) + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal: + base_type = self.get_base_type(python_type) base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) return base_transformer.to_literal(ctx, python_val, python_type, expected) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object: - args = get_args(expected_python_type) - if not args: - raise TypeTransformerFailedError("Literal must have at least one value") - - base_type = type(args[0]) - if not all(type(a) == base_type for a in args): - raise TypeTransformerFailedError("All values must be of the same type") - + base_type = self.get_base_type(expected_python_type) base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) return base_transformer.to_python_value(ctx, lv, base_type) - def guess_python_type(self, literal_type: LiteralType): - return TypeEngine.guess_python_type(literal_type) + def guess_python_type(self, literal_type: LiteralType) -> Type: + ann = getattr(literal_type, "annotation", None) + if ann and getattr(ann, "annotations", None): + vals = ann.annotations.get("literal_values") + if vals and isinstance(vals, list): + return typing.Literal[tuple(vals)] # type: ignore + raise ValueError(f"LiteralType transformer cannot reverse {literal_type}") def assert_type(self, python_type: Type, python_val: T): - args = get_args(python_type) - if not args: - raise TypeTransformerFailedError("Literal must have at least one value") - - base_type = type(args[0]) - if not all(type(a) == base_type for a in args): - raise TypeTransformerFailedError("All values must be of the same type") - + base_type = self.get_base_type(python_type) base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type) return base_transformer.assert_type(base_type, python_val) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 93d5d6af67..0d92f90d90 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3851,3 +3851,42 @@ async def test_dict_transformer_annotated_type(): literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type) assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42 + +def test_literal_transformer_string_type(): + # Python -> Flyte + t = typing.Literal["outcome", "income"] + lt = TypeEngine.get_transformer(t).get_literal_type(t) + assert lt.simple == SimpleType.STRING + assert lt.annotation.annotations["literal_values"] == ["outcome", "income"] + + lv = TypeEngine.to_literal(FlyteContext.current_context(), "outcome", t, lt) + assert lv.scalar.primitive.string_value == "outcome" + + # Flyte -> Python (reconstruction) + pt = TypeEngine.get_transformer(t).guess_python_type(lt) + assert pt is typing.Literal["outcome", "income"] + pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt) + TypeEngine.get_transformer(pt).assert_type(pt, pv) + assert pv == "outcome" + +def test_literal_transformer_int_type(): + # Python -> Flyte + t = typing.Literal[1, 2, 3] + lt = TypeEngine.get_transformer(t).get_literal_type(t) + assert lt.simple == SimpleType.INTEGER + assert lt.annotation.annotations["literal_values"] == [1, 2, 3] + + lv = TypeEngine.to_literal(FlyteContext.current_context(), 1, t, lt) + assert lv.scalar.primitive.integer == 1 + + # Flyte -> Python (reconstruction) + pt = TypeEngine.get_transformer(t).guess_python_type(lt) + assert pt is typing.Literal[1, 2, 3] + pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt) + TypeEngine.get_transformer(pt).assert_type(pt, pv) + assert pv == 1 + +def test_literal_transformer_mixed_base_types(): + t = typing.Literal["a", 1] + with pytest.raises(TypeTransformerFailedError): + TypeEngine.get_transformer(t).get_literal_type(t) From e0aa3d8ca95efd7f821ba8fc17a703639f7797fa Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Tue, 23 Sep 2025 21:19:01 -0500 Subject: [PATCH 7/9] Fix typo and add some tests Signed-off-by: Barry Wu --- flytekit/core/type_engine.py | 7 ++----- tests/flytekit/unit/core/test_type_engine.py | 2 ++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 722f831049..d8148158ce 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1134,11 +1134,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return base_transformer.to_python_value(ctx, lv, base_type) def guess_python_type(self, literal_type: LiteralType) -> Type: - ann = getattr(literal_type, "annotation", None) - if ann and getattr(ann, "annotations", None): - vals = ann.annotations.get("literal_values") - if vals and isinstance(vals, list): - return typing.Literal[tuple(vals)] # type: ignore + if literal_type.annotation and literal_type.annotation.annotations: + return typing.Literal[tuple(literal_type.annotation.annotations.get("literal_values"))] # type: ignore raise ValueError(f"LiteralType transformer cannot reverse {literal_type}") def assert_type(self, python_type: Type, python_val: T): diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0d92f90d90..280ad452ee 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3858,6 +3858,7 @@ def test_literal_transformer_string_type(): lt = TypeEngine.get_transformer(t).get_literal_type(t) assert lt.simple == SimpleType.STRING assert lt.annotation.annotations["literal_values"] == ["outcome", "income"] + assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl()) lv = TypeEngine.to_literal(FlyteContext.current_context(), "outcome", t, lt) assert lv.scalar.primitive.string_value == "outcome" @@ -3875,6 +3876,7 @@ def test_literal_transformer_int_type(): lt = TypeEngine.get_transformer(t).get_literal_type(t) assert lt.simple == SimpleType.INTEGER assert lt.annotation.annotations["literal_values"] == [1, 2, 3] + assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl()) lv = TypeEngine.to_literal(FlyteContext.current_context(), 1, t, lt) assert lv.scalar.primitive.integer == 1 From b6f40ce4f57dc364061224bf3f90dcd6d46e93ae Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Fri, 26 Sep 2025 19:39:28 -0500 Subject: [PATCH 8/9] Fix conflicts Signed-off-by: Barry Wu --- tests/flytekit/unit/core/test_type_engine.py | 142 +++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index e82479f5a9..cb5966871f 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3852,6 +3852,148 @@ async def test_dict_transformer_annotated_type(): literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type) assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42 +@pytest.fixture(autouse=True) +def clear_type_engine_cache(): + """Clear TypeEngine cache before and after each test""" + TypeEngine._LITERAL_CACHE.clear() + yield + TypeEngine._LITERAL_CACHE.clear() + +def test_type_engine_cache_with_list(): + ctx = FlyteContext.current_context() + python_val = [1, 2, 3, 4, 5] + python_type = typing.List[int] + expected = TypeEngine.to_literal_type(python_type) + list_transformer = TypeEngine.get_transformer(typing.List[int]) + with mock.patch.object(list_transformer, 'async_to_literal', + wraps=list_transformer.async_to_literal) as mock_async_to_literal: + + # First call + literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + key = TypeEngine._get_literal_cache_key(python_val, python_type) + assert key is not None + assert key in TypeEngine._LITERAL_CACHE + + # Second call with same DataFrame + literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + # Verify async_to_literal was only called once + assert mock_async_to_literal.call_count == 1 + + assert literal1 is literal2 + + # Test with different data - should not use cache + different_val = [2, 1, 3, 4, 5] + literal3 = TypeEngine.to_literal(ctx, different_val, python_type, expected) + key_different = TypeEngine._get_literal_cache_key(different_val, python_type) + + assert key_different is not key + assert key_different is not None + assert key_different in TypeEngine._LITERAL_CACHE + + # Verify different literals are different objects + assert literal1 is not literal3 + + # Add many different values to test cache size limit + for i in range(200): # More than the default maxsize of 128 + test_val = [i, i+1, i+2] + test_type = typing.List[int] + test_expected = TypeEngine.to_literal_type(test_type) + TypeEngine.to_literal(ctx, test_val, test_type, test_expected) + + # Cache should not exceed maxsize + assert len(TypeEngine._LITERAL_CACHE) == 128 + +def test_type_engine_cache_with_dict(): + ctx = FlyteContext.current_context() + python_val = {"a": [1, 2, 3]} + python_type = typing.Dict[str, typing.List[int]] + expected = TypeEngine.to_literal_type(python_type) + dict_transformer = TypeEngine.get_transformer(typing.Dict[str, typing.List[int]]) + with mock.patch.object(dict_transformer, 'async_to_literal', + wraps=dict_transformer.async_to_literal) as mock_async_to_literal: + + # First call + literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + key = TypeEngine._get_literal_cache_key(python_val, python_type) + assert key is not None + assert key in TypeEngine._LITERAL_CACHE + + # Second call with same DataFrame + literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + # Verify async_to_literal was only called once + assert mock_async_to_literal.call_count == 1 + + assert literal1 is literal2 + +def test_make_key_with_annotated_types(): + # Test with Annotated type + annotated_val = [1, 2, 3] + annotated_type = typing.Annotated[typing.List[int], "test_annotation"] + + key = TypeEngine._get_literal_cache_key(annotated_val, annotated_type) + key_without_annotation = TypeEngine._get_literal_cache_key(annotated_val, typing.List[int]) + # Should handle Annotated types correctly + assert key is not None + assert key_without_annotation is not None + assert key != key_without_annotation + +def test_type_engine_cache_with_pandas(): + pd = pytest.importorskip("pandas") + ctx = FlyteContext.current_context() + # Create DataFrame + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_type = pd.DataFrame + df_expected = TypeEngine.to_literal_type(df_type) + + # Get the transformer for DataFrame + df_transformer = TypeEngine._REGISTRY[pd.DataFrame] + + # Mock the async_to_literal method with wraps to track calls + with mock.patch.object(df_transformer, 'async_to_literal', + wraps=df_transformer.async_to_literal) as mock_async_to_literal: + + # First call + literal1 = TypeEngine.to_literal(ctx, df, df_type, df_expected) + + # Second call with same DataFrame + literal2 = TypeEngine.to_literal(ctx, df, df_type, df_expected) + + # Verify async_to_literal was called + assert mock_async_to_literal.call_count == 1 + + assert literal1 is literal2 + +def test_type_engine_cache_with_flytefile(): + + transformer = TypeEngine.get_transformer(FlyteFile) + ctx = FlyteContext.current_context() + + temp_dir = tempfile.mkdtemp(prefix="temp_example_") + file_path = os.path.join(temp_dir, "file.txt") + with open(file_path, "w") as file1: + file1.write("hello world") + + lt = TypeEngine.to_literal_type(FlyteFile) + + # Mock the file upload + with mock.patch.object(transformer, 'async_to_literal', + wraps=transformer.async_to_literal) as mock_async_to_literal: + + # Test 1: Upload local file to remote + lv1 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt) + + # Second call with same DataFrame + lv2 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt) + + # Verify async_to_literal was called + assert mock_async_to_literal.call_count == 1 + + assert lv1 is lv2 + def test_literal_transformer_string_type(): # Python -> Flyte t = typing.Literal["outcome", "income"] From bcd136fb6bbc0954087a8d3dc8b1abbdbcdbb2ed Mon Sep 17 00:00:00 2001 From: Barry Wu Date: Fri, 26 Sep 2025 19:45:07 -0500 Subject: [PATCH 9/9] Fix lint Signed-off-by: Barry Wu --- tests/flytekit/unit/core/test_type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index cb5966871f..ef139b6701 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -4033,4 +4033,4 @@ def test_literal_transformer_int_type(): def test_literal_transformer_mixed_base_types(): t = typing.Literal["a", 1] with pytest.raises(TypeTransformerFailedError): - TypeEngine.get_transformer(t).get_literal_type(t) \ No newline at end of file + TypeEngine.get_transformer(t).get_literal_type(t)