@@ -878,7 +878,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
878878 return list (map (lambda x : self ._fix_val_int (ListTransformer .get_sub_type (t ), x ), val ))
879879
880880 if isinstance (val , dict ):
881- ktype , vtype = DictTransformer .extract_types_or_metadata (t )
881+ ktype , vtype = DictTransformer .extract_types (t )
882882 # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}})
883883 return {
884884 self ._fix_val_int (cast (type , ktype ), k ): self ._fix_val_int (cast (type , vtype ), v ) for k , v in val .items ()
@@ -2018,7 +2018,7 @@ def __init__(self):
20182018 super ().__init__ ("Typed Dict" , dict )
20192019
20202020 @staticmethod
2021- def extract_types_or_metadata (t : Optional [Type [dict ]]) -> typing .Tuple :
2021+ def extract_types (t : Optional [Type [dict ]]) -> typing .Tuple :
20222022 _origin = get_origin (t )
20232023 _args = get_args (t )
20242024 if _origin is not None :
@@ -2031,8 +2031,12 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
20312031 raise ValueError (
20322032 f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. { t } cannot be parsed."
20332033 )
2034- if _origin in [ dict , Annotated ] and _args is not None :
2034+ if _origin is dict and _args is not None :
20352035 return _args # type: ignore
2036+ elif _origin is Annotated :
2037+ return DictTransformer .extract_types (_args [0 ])
2038+ else :
2039+ raise ValueError (f"Trying to extract dictionary type information from a non-dict type { t } " )
20362040 return None , None
20372041
20382042 @staticmethod
@@ -2099,31 +2103,24 @@ async def dict_to_binary_literal(
20992103 raise TypeTransformerFailedError (f"Cannot convert `{ v } ` to Flyte Literal.\n " f"Error Message: { e } " )
21002104
21012105 @staticmethod
2102- def is_pickle (python_type : Type [dict ]) -> typing .Tuple [bool , Type ]:
2103- base_type , * metadata = DictTransformer .extract_types_or_metadata (python_type )
2106+ def is_pickle (python_type : Type [dict ]) -> bool :
2107+ _origin = get_origin (python_type )
2108+ metadata : typing .Tuple = ()
2109+ if _origin is Annotated :
2110+ metadata = get_args (python_type )[1 :]
21042111
21052112 for each_metadata in metadata :
21062113 if isinstance (each_metadata , OrderedDict ):
21072114 allow_pickle = each_metadata .get ("allow_pickle" , False )
2108- return allow_pickle , base_type
2109-
2110- return False , base_type
2115+ return allow_pickle
21112116
2112- @staticmethod
2113- def dict_types (python_type : Type ) -> typing .Tuple [typing .Any , ...]:
2114- if get_origin (python_type ) is Annotated :
2115- base_type , * _ = DictTransformer .extract_types_or_metadata (python_type )
2116- tp = get_args (base_type )
2117- else :
2118- tp = DictTransformer .extract_types_or_metadata (python_type )
2119-
2120- return tp
2117+ return False
21212118
21222119 def get_literal_type (self , t : Type [dict ]) -> LiteralType :
21232120 """
21242121 Transforms a native python dictionary to a flyte-specific ``LiteralType``
21252122 """
2126- tp = self . dict_types (t )
2123+ tp = DictTransformer . extract_types (t )
21272124
21282125 if tp :
21292126 if tp [0 ] == str :
@@ -2144,10 +2141,9 @@ async def async_to_literal(
21442141 raise TypeTransformerFailedError ("Expected a dict" )
21452142
21462143 allow_pickle = False
2147- base_type = None
21482144
21492145 if get_origin (python_type ) is Annotated :
2150- allow_pickle , base_type = DictTransformer .is_pickle (python_type )
2146+ allow_pickle = DictTransformer .is_pickle (python_type )
21512147
21522148 if expected and expected .simple and expected .simple == SimpleType .STRUCT :
21532149 if str2bool (os .getenv (FLYTE_USE_OLD_DC_FORMAT )):
@@ -2160,11 +2156,7 @@ async def async_to_literal(
21602156 raise ValueError ("Flyte MapType expects all keys to be strings" )
21612157 # TODO: log a warning for Annotated objects that contain HashMethod
21622158
2163- if base_type :
2164- _ , v_type = get_args (base_type )
2165- else :
2166- _ , v_type = self .extract_types_or_metadata (python_type )
2167-
2159+ _ , v_type = self .extract_types (python_type )
21682160 lit_map [k ] = TypeEngine .async_to_literal (ctx , v , cast (type , v_type ), expected .map_value_type )
21692161 vals = await _run_coros_in_chunks ([c for c in lit_map .values ()], batch_size = _TYPE_ENGINE_COROS_BATCH_SIZE )
21702162 for idx , k in zip (range (len (vals )), lit_map .keys ()):
@@ -2177,9 +2169,9 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
21772169 return self .from_binary_idl (lv .scalar .binary , expected_python_type ) # type: ignore
21782170
21792171 if lv and lv .map and lv .map .literals is not None :
2180- tp = self . dict_types (expected_python_type )
2172+ tp = DictTransformer . extract_types (expected_python_type )
21812173
2182- if tp is None or tp [0 ] is None :
2174+ if tp is None or len ( tp ) == 0 or tp [0 ] is None :
21832175 raise TypeError (
21842176 "TypeMismatch: Cannot convert to python dictionary from Flyte Literal Dictionary as the given "
21852177 "dictionary does not have sub-type hints or they do not match with the originating dictionary "
0 commit comments