Skip to content

Commit 23de1ab

Browse files
BarryWu0812Barry Wu
andauthored
[Core feature] Add support Literal Transformer (#3304)
Signed-off-by: Barry Wu <a0987818905@gmail.com> Co-authored-by: Barry Wu <barry@wukanglideMacBook-Pro.local>
1 parent 4d21550 commit 23de1ab

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

flytekit/core/type_engine.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,63 @@ def assert_type(self, t: Type[enum.Enum], v: T):
10881088
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")
10891089

10901090

1091+
class LiteralTypeTransformer(TypeTransformer[object]):
1092+
def __init__(self):
1093+
super().__init__("LiteralTypeTransformer", object)
1094+
1095+
@classmethod
1096+
def get_base_type(cls, t: Type) -> Type:
1097+
args = get_args(t)
1098+
if not args:
1099+
raise TypeTransformerFailedError("Literal must have at least one value")
1100+
1101+
base_type = type(args[0])
1102+
if not all(type(a) == base_type for a in args):
1103+
raise TypeTransformerFailedError("All values must be of the same type")
1104+
1105+
return base_type
1106+
1107+
def get_literal_type(self, t: Type) -> LiteralType:
1108+
base_type = self.get_base_type(t)
1109+
vals = list(get_args(t))
1110+
ann = TypeAnnotationModel(annotations={"literal_values": vals})
1111+
if base_type is str:
1112+
simple = SimpleType.STRING
1113+
elif base_type is int:
1114+
simple = SimpleType.INTEGER
1115+
elif base_type is float:
1116+
simple = SimpleType.FLOAT
1117+
elif base_type is bool:
1118+
simple = SimpleType.BOOLEAN
1119+
elif base_type is datetime.datetime:
1120+
simple = SimpleType.DATETIME
1121+
elif base_type is datetime.timedelta:
1122+
simple = SimpleType.DURATION
1123+
else:
1124+
raise TypeTransformerFailedError(f"Unsupported type: {base_type}")
1125+
return LiteralType(simple=simple, annotation=ann)
1126+
1127+
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal:
1128+
base_type = self.get_base_type(python_type)
1129+
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
1130+
return base_transformer.to_literal(ctx, python_val, python_type, expected)
1131+
1132+
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object:
1133+
base_type = self.get_base_type(expected_python_type)
1134+
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
1135+
return base_transformer.to_python_value(ctx, lv, base_type)
1136+
1137+
def guess_python_type(self, literal_type: LiteralType) -> Type:
1138+
if literal_type.annotation and literal_type.annotation.annotations:
1139+
return typing.Literal[tuple(literal_type.annotation.annotations.get("literal_values"))] # type: ignore
1140+
raise ValueError(f"LiteralType transformer cannot reverse {literal_type}")
1141+
1142+
def assert_type(self, python_type: Type, python_val: T):
1143+
base_type = self.get_base_type(python_type)
1144+
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
1145+
return base_transformer.assert_type(base_type, python_val)
1146+
1147+
10911148
def _handle_json_schema_property(
10921149
property_key: str,
10931150
property_val: dict,
@@ -1174,6 +1231,7 @@ class TypeEngine(typing.Generic[T]):
11741231
_RESTRICTED_TYPES: typing.List[type] = []
11751232
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
11761233
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
1234+
_LITERAL_TYPE_TRANSFORMER: TypeTransformer = LiteralTypeTransformer()
11771235
lazy_import_lock = threading.Lock()
11781236
_LITERAL_CACHE: LRUCache = LRUCache(maxsize=128)
11791237

@@ -1224,6 +1282,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
12241282
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
12251283
return cls._ENUM_TRANSFORMER
12261284

1285+
if get_origin(python_type) == typing.Literal:
1286+
return cls._LITERAL_TYPE_TRANSFORMER
1287+
12271288
if hasattr(python_type, "__origin__"):
12281289
# If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
12291290
# or List[int] has been specifically registered; we should check for the entire type.

tests/flytekit/unit/core/test_type_engine.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3993,3 +3993,44 @@ def test_type_engine_cache_with_flytefile():
39933993
assert mock_async_to_literal.call_count == 1
39943994

39953995
assert lv1 is lv2
3996+
3997+
def test_literal_transformer_string_type():
3998+
# Python -> Flyte
3999+
t = typing.Literal["outcome", "income"]
4000+
lt = TypeEngine.get_transformer(t).get_literal_type(t)
4001+
assert lt.simple == SimpleType.STRING
4002+
assert lt.annotation.annotations["literal_values"] == ["outcome", "income"]
4003+
assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl())
4004+
4005+
lv = TypeEngine.to_literal(FlyteContext.current_context(), "outcome", t, lt)
4006+
assert lv.scalar.primitive.string_value == "outcome"
4007+
4008+
# Flyte -> Python (reconstruction)
4009+
pt = TypeEngine.get_transformer(t).guess_python_type(lt)
4010+
assert pt is typing.Literal["outcome", "income"]
4011+
pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt)
4012+
TypeEngine.get_transformer(pt).assert_type(pt, pv)
4013+
assert pv == "outcome"
4014+
4015+
def test_literal_transformer_int_type():
4016+
# Python -> Flyte
4017+
t = typing.Literal[1, 2, 3]
4018+
lt = TypeEngine.get_transformer(t).get_literal_type(t)
4019+
assert lt.simple == SimpleType.INTEGER
4020+
assert lt.annotation.annotations["literal_values"] == [1, 2, 3]
4021+
assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl())
4022+
4023+
lv = TypeEngine.to_literal(FlyteContext.current_context(), 1, t, lt)
4024+
assert lv.scalar.primitive.integer == 1
4025+
4026+
# Flyte -> Python (reconstruction)
4027+
pt = TypeEngine.get_transformer(t).guess_python_type(lt)
4028+
assert pt is typing.Literal[1, 2, 3]
4029+
pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt)
4030+
TypeEngine.get_transformer(pt).assert_type(pt, pv)
4031+
assert pv == 1
4032+
4033+
def test_literal_transformer_mixed_base_types():
4034+
t = typing.Literal["a", 1]
4035+
with pytest.raises(TypeTransformerFailedError):
4036+
TypeEngine.get_transformer(t).get_literal_type(t)

0 commit comments

Comments
 (0)