class LiteralTypeTransformer(TypeTransformer[object]):
    """
    Transformer for ``typing.Literal[...]`` annotations.

    A Literal is mapped onto the Flyte simple type that corresponds to the Python type of its values; all
    values must share exactly one Python type. The permitted values are stored in the LiteralType's annotation
    under the ``literal_values`` key so that ``guess_python_type`` can rebuild the Python annotation.
    Conversion of the actual values is delegated to the transformer that ``TypeEngine`` returns for the
    base type.
    """

    def __init__(self):
        super().__init__("LiteralTypeTransformer", object)

    @classmethod
    def get_base_type(cls, t: Type) -> Type:
        """
        Return the single Python type shared by every value of the Literal ``t``.

        :param t: A ``typing.Literal[...]`` annotation.
        :raises TypeTransformerFailedError: If ``t`` carries no values or its values are of different types.
        """
        args = get_args(t)
        if not args:
            raise TypeTransformerFailedError("Literal must have at least one value")

        base_type = type(args[0])
        # Exact type identity (not isinstance) is intentional: bool is a subclass of int, but
        # Literal[True, 1] must still be rejected as a mix of two types.
        if any(type(a) is not base_type for a in args[1:]):
            raise TypeTransformerFailedError("All values must be of the same type")

        return base_type

    @staticmethod
    def _assert_allowed_value(python_type: Type, python_val: object) -> None:
        """Raise TypeTransformerFailedError unless ``python_val`` equals one of the Literal's values."""
        if python_val not in get_args(python_type):
            raise TypeTransformerFailedError(f"Value {python_val!r} is not one of {python_type}")

    def get_literal_type(self, t: Type) -> LiteralType:
        """
        Build the Flyte LiteralType for the Literal ``t``.

        The result is a simple type chosen from the base type of the values, annotated with the list of
        permitted values under ``literal_values``.

        :raises TypeTransformerFailedError: If the values are mixed, missing, or of an unsupported type.
        """
        base_type = self.get_base_type(t)
        simple_by_type = {
            str: SimpleType.STRING,
            int: SimpleType.INTEGER,
            float: SimpleType.FLOAT,
            bool: SimpleType.BOOLEAN,
            datetime.datetime: SimpleType.DATETIME,
            datetime.timedelta: SimpleType.DURATION,
        }
        if base_type not in simple_by_type:
            raise TypeTransformerFailedError(f"Unsupported type: {base_type}")

        ann = TypeAnnotationModel(annotations={"literal_values": list(get_args(t))})
        return LiteralType(simple=simple_by_type[base_type], annotation=ann)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal:
        """
        Convert ``python_val`` to a Flyte literal using the base type's transformer.

        :raises TypeTransformerFailedError: If ``python_val`` is not one of the values permitted by ``python_type``.
        """
        base_type = self.get_base_type(python_type)
        self._assert_allowed_value(python_type, python_val)
        base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
        # Hand the base transformer the type it was looked up for, not the Literal annotation, matching
        # what to_python_value and assert_type do.
        return base_transformer.to_literal(ctx, python_val, base_type, expected)

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object:
        """Convert the Flyte literal ``lv`` back to a Python value using the base type's transformer."""
        base_type = self.get_base_type(expected_python_type)
        base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
        return base_transformer.to_python_value(ctx, lv, base_type)

    def guess_python_type(self, literal_type: LiteralType) -> Type:
        """
        Rebuild ``typing.Literal[...]`` from the ``literal_values`` annotation of ``literal_type``.

        :raises ValueError: If ``literal_type`` has no annotation or no ``literal_values`` entry.
        """
        annotation = literal_type.annotation
        values = annotation.annotations.get("literal_values") if annotation and annotation.annotations else None
        if not values:
            # Covers both "no annotation at all" and "annotated by something else"; the latter used to
            # surface as a TypeError from tuple(None).
            raise ValueError(f"LiteralType transformer cannot reverse {literal_type}")

        if literal_type.simple == SimpleType.INTEGER:
            # NOTE(review): annotations presumably travel through a protobuf Struct, which stores every
            # number as a double, so integer literals may come back as 1.0 -- confirm against
            # TypeAnnotationModel serialization.
            values = [int(v) if isinstance(v, float) and v.is_integer() else v for v in values]

        return typing.Literal[tuple(values)]  # type: ignore

    def assert_type(self, python_type: Type, python_val: T):
        """
        Check that ``python_val`` is acceptable for the Literal ``python_type``.

        The base type's transformer validates the value's type first; the value must then also be one of
        the Literal's permitted values.

        :raises TypeTransformerFailedError: If the value is not one of the permitted values (the base
            transformer may raise its own error for a type mismatch).
        """
        base_type = self.get_base_type(python_type)
        base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
        base_transformer.assert_type(base_type, python_val)
        self._assert_allowed_value(python_type, python_val)
def test_literal_transformer_string_type():
    """Round-trip a string Literal: Python type -> LiteralType -> literal -> Python value."""
    ctx = FlyteContext.current_context()
    t = typing.Literal["outcome", "income"]
    transformer = TypeEngine.get_transformer(t)

    # Python -> Flyte
    lt = transformer.get_literal_type(t)
    assert lt.simple == SimpleType.STRING
    assert lt.annotation.annotations["literal_values"] == ["outcome", "income"]
    assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl())

    lv = TypeEngine.to_literal(ctx, "outcome", t, lt)
    assert lv.scalar.primitive.string_value == "outcome"

    # Flyte -> Python (reconstruction)
    pt = transformer.guess_python_type(lt)
    # Compare by equality: object identity of Literal[...] depends on typing's internal cache.
    assert pt == typing.Literal["outcome", "income"]
    # Look the transformer up again to check that the reconstructed type dispatches correctly too.
    rebuilt_transformer = TypeEngine.get_transformer(pt)
    pv = rebuilt_transformer.to_python_value(ctx, lv, pt)
    rebuilt_transformer.assert_type(pt, pv)
    assert pv == "outcome"


def test_literal_transformer_int_type():
    """Round-trip an integer Literal: Python type -> LiteralType -> literal -> Python value."""
    ctx = FlyteContext.current_context()
    t = typing.Literal[1, 2, 3]
    transformer = TypeEngine.get_transformer(t)

    # Python -> Flyte
    lt = transformer.get_literal_type(t)
    assert lt.simple == SimpleType.INTEGER
    assert lt.annotation.annotations["literal_values"] == [1, 2, 3]
    assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl())

    lv = TypeEngine.to_literal(ctx, 1, t, lt)
    assert lv.scalar.primitive.integer == 1

    # Flyte -> Python (reconstruction)
    pt = transformer.guess_python_type(lt)
    # Compare by equality: object identity of Literal[...] depends on typing's internal cache.
    assert pt == typing.Literal[1, 2, 3]
    # Look the transformer up again to check that the reconstructed type dispatches correctly too.
    rebuilt_transformer = TypeEngine.get_transformer(pt)
    pv = rebuilt_transformer.to_python_value(ctx, lv, pt)
    rebuilt_transformer.assert_type(pt, pv)
    assert pv == 1


def test_literal_transformer_mixed_base_types():
    """A Literal whose values have different Python types must be rejected."""
    t = typing.Literal["a", 1]
    with pytest.raises(TypeTransformerFailedError):
        TypeEngine.get_transformer(t).get_literal_type(t)