@@ -1088,6 +1088,63 @@ def assert_type(self, t: Type[enum.Enum], v: T):
10881088 raise TypeTransformerFailedError (f"Value { v } is not in Enum { t } " )
10891089
10901090
1091+ class LiteralTypeTransformer (TypeTransformer [object ]):
1092+ def __init__ (self ):
1093+ super ().__init__ ("LiteralTypeTransformer" , object )
1094+
1095+ @classmethod
1096+ def get_base_type (cls , t : Type ) -> Type :
1097+ args = get_args (t )
1098+ if not args :
1099+ raise TypeTransformerFailedError ("Literal must have at least one value" )
1100+
1101+ base_type = type (args [0 ])
1102+ if not all (type (a ) == base_type for a in args ):
1103+ raise TypeTransformerFailedError ("All values must be of the same type" )
1104+
1105+ return base_type
1106+
1107+ def get_literal_type (self , t : Type ) -> LiteralType :
1108+ base_type = self .get_base_type (t )
1109+ vals = list (get_args (t ))
1110+ ann = TypeAnnotationModel (annotations = {"literal_values" : vals })
1111+ if base_type is str :
1112+ simple = SimpleType .STRING
1113+ elif base_type is int :
1114+ simple = SimpleType .INTEGER
1115+ elif base_type is float :
1116+ simple = SimpleType .FLOAT
1117+ elif base_type is bool :
1118+ simple = SimpleType .BOOLEAN
1119+ elif base_type is datetime .datetime :
1120+ simple = SimpleType .DATETIME
1121+ elif base_type is datetime .timedelta :
1122+ simple = SimpleType .DURATION
1123+ else :
1124+ raise TypeTransformerFailedError (f"Unsupported type: { base_type } " )
1125+ return LiteralType (simple = simple , annotation = ann )
1126+
1127+ def to_literal (self , ctx : FlyteContext , python_val : T , python_type : Type , expected : LiteralType ) -> Literal :
1128+ base_type = self .get_base_type (python_type )
1129+ base_transformer : TypeTransformer [object ] = TypeEngine .get_transformer (base_type )
1130+ return base_transformer .to_literal (ctx , python_val , python_type , expected )
1131+
1132+ def to_python_value (self , ctx : FlyteContext , lv : Literal , expected_python_type : Type ) -> object :
1133+ base_type = self .get_base_type (expected_python_type )
1134+ base_transformer : TypeTransformer [object ] = TypeEngine .get_transformer (base_type )
1135+ return base_transformer .to_python_value (ctx , lv , base_type )
1136+
1137+ def guess_python_type (self , literal_type : LiteralType ) -> Type :
1138+ if literal_type .annotation and literal_type .annotation .annotations :
1139+ return typing .Literal [tuple (literal_type .annotation .annotations .get ("literal_values" ))] # type: ignore
1140+ raise ValueError (f"LiteralType transformer cannot reverse { literal_type } " )
1141+
1142+ def assert_type (self , python_type : Type , python_val : T ):
1143+ base_type = self .get_base_type (python_type )
1144+ base_transformer : TypeTransformer [object ] = TypeEngine .get_transformer (base_type )
1145+ return base_transformer .assert_type (base_type , python_val )
1146+
1147+
10911148def _handle_json_schema_property (
10921149 property_key : str ,
10931150 property_val : dict ,
@@ -1174,6 +1231,7 @@ class TypeEngine(typing.Generic[T]):
11741231 _RESTRICTED_TYPES : typing .List [type ] = []
11751232 _DATACLASS_TRANSFORMER : TypeTransformer = DataclassTransformer () # type: ignore
11761233 _ENUM_TRANSFORMER : TypeTransformer = EnumTransformer () # type: ignore
1234+ _LITERAL_TYPE_TRANSFORMER : TypeTransformer = LiteralTypeTransformer ()
11771235 lazy_import_lock = threading .Lock ()
11781236 _LITERAL_CACHE : LRUCache = LRUCache (maxsize = 128 )
11791237
@@ -1224,6 +1282,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
12241282 # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
12251283 return cls ._ENUM_TRANSFORMER
12261284
1285+ if get_origin (python_type ) == typing .Literal :
1286+ return cls ._LITERAL_TYPE_TRANSFORMER
1287+
12271288 if hasattr (python_type , "__origin__" ):
12281289 # If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
12291290 # or List[int] has been specifically registered; we should check for the entire type.
0 commit comments