Skip to content

Commit d732d6f

Browse files
Fix dict annotations (#3123)
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
1 parent a38fc5c commit d732d6f

File tree

5 files changed: 103 additions and 50 deletions.

flytekit/core/promise.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -947,7 +947,7 @@ async def binding_data_from_python_std(
947947
if transformer_override and hasattr(transformer_override, "extract_types_or_metadata"):
948948
_, v_type = transformer_override.extract_types_or_metadata(t_value_type) # type: ignore
949949
else:
950-
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type) # type: ignore
950+
_, v_type = DictTransformer.extract_types(cast(typing.Type[dict], t_value_type))
951951
m = _literals_models.BindingDataMap(
952952
bindings={
953953
k: await binding_data_from_python_std(

flytekit/core/type_engine.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
878878
return list(map(lambda x: self._fix_val_int(ListTransformer.get_sub_type(t), x), val))
879879

880880
if isinstance(val, dict):
881-
ktype, vtype = DictTransformer.extract_types_or_metadata(t)
881+
ktype, vtype = DictTransformer.extract_types(t)
882882
# Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}})
883883
return {
884884
self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items()
@@ -2018,7 +2018,7 @@ def __init__(self):
20182018
super().__init__("Typed Dict", dict)
20192019

20202020
@staticmethod
2021-
def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
2021+
def extract_types(t: Optional[Type[dict]]) -> typing.Tuple:
20222022
_origin = get_origin(t)
20232023
_args = get_args(t)
20242024
if _origin is not None:
@@ -2031,8 +2031,12 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
20312031
raise ValueError(
20322032
f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. {t} cannot be parsed."
20332033
)
2034-
if _origin in [dict, Annotated] and _args is not None:
2034+
if _origin is dict and _args is not None:
20352035
return _args # type: ignore
2036+
elif _origin is Annotated:
2037+
return DictTransformer.extract_types(_args[0])
2038+
else:
2039+
raise ValueError(f"Trying to extract dictionary type information from a non-dict type {t}")
20362040
return None, None
20372041

20382042
@staticmethod
@@ -2099,31 +2103,24 @@ async def dict_to_binary_literal(
20992103
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")
21002104

21012105
@staticmethod
2102-
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
2103-
base_type, *metadata = DictTransformer.extract_types_or_metadata(python_type)
2106+
def is_pickle(python_type: Type[dict]) -> bool:
2107+
_origin = get_origin(python_type)
2108+
metadata: typing.Tuple = ()
2109+
if _origin is Annotated:
2110+
metadata = get_args(python_type)[1:]
21042111

21052112
for each_metadata in metadata:
21062113
if isinstance(each_metadata, OrderedDict):
21072114
allow_pickle = each_metadata.get("allow_pickle", False)
2108-
return allow_pickle, base_type
2109-
2110-
return False, base_type
2115+
return allow_pickle
21112116

2112-
@staticmethod
2113-
def dict_types(python_type: Type) -> typing.Tuple[typing.Any, ...]:
2114-
if get_origin(python_type) is Annotated:
2115-
base_type, *_ = DictTransformer.extract_types_or_metadata(python_type)
2116-
tp = get_args(base_type)
2117-
else:
2118-
tp = DictTransformer.extract_types_or_metadata(python_type)
2119-
2120-
return tp
2117+
return False
21212118

21222119
def get_literal_type(self, t: Type[dict]) -> LiteralType:
21232120
"""
21242121
Transforms a native python dictionary to a flyte-specific ``LiteralType``
21252122
"""
2126-
tp = self.dict_types(t)
2123+
tp = DictTransformer.extract_types(t)
21272124

21282125
if tp:
21292126
if tp[0] == str:
@@ -2144,10 +2141,9 @@ async def async_to_literal(
21442141
raise TypeTransformerFailedError("Expected a dict")
21452142

21462143
allow_pickle = False
2147-
base_type = None
21482144

21492145
if get_origin(python_type) is Annotated:
2150-
allow_pickle, base_type = DictTransformer.is_pickle(python_type)
2146+
allow_pickle = DictTransformer.is_pickle(python_type)
21512147

21522148
if expected and expected.simple and expected.simple == SimpleType.STRUCT:
21532149
if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)):
@@ -2160,11 +2156,7 @@ async def async_to_literal(
21602156
raise ValueError("Flyte MapType expects all keys to be strings")
21612157
# TODO: log a warning for Annotated objects that contain HashMethod
21622158

2163-
if base_type:
2164-
_, v_type = get_args(base_type)
2165-
else:
2166-
_, v_type = self.extract_types_or_metadata(python_type)
2167-
2159+
_, v_type = self.extract_types(python_type)
21682160
lit_map[k] = TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
21692161
vals = await _run_coros_in_chunks([c for c in lit_map.values()], batch_size=_TYPE_ENGINE_COROS_BATCH_SIZE)
21702162
for idx, k in zip(range(len(vals)), lit_map.keys()):
@@ -2177,9 +2169,9 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
21772169
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
21782170

21792171
if lv and lv.map and lv.map.literals is not None:
2180-
tp = self.dict_types(expected_python_type)
2172+
tp = DictTransformer.extract_types(expected_python_type)
21812173

2182-
if tp is None or tp[0] is None:
2174+
if tp is None or len(tp) == 0 or tp[0] is None:
21832175
raise TypeError(
21842176
"TypeMismatch: Cannot convert to python dictionary from Flyte Literal Dictionary as the given "
21852177
"dictionary does not have sub-type hints or they do not match with the originating dictionary "

tests/flytekit/unit/core/test_annotated_bindings.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import asyncio
2+
from flytekit.core.artifact import Artifact
23
from dataclasses import dataclass
3-
from typing import List, Optional, TypeVar, Type, Tuple
4+
from typing import List, Optional, TypeVar, Type, Tuple, Union
45

6+
from pydantic import BaseModel
57
from typing_extensions import Annotated
68

7-
from flytekit import task, workflow
9+
from flytekit import task
10+
from flytekit import workflow
811
from flytekit.configuration import SerializationSettings, ImageConfig, Image
9-
from flytekit.core.context_manager import FlyteContextManager, ExecutionState, FlyteContext
12+
from flytekit.core.context_manager import FlyteContext
13+
from flytekit.core.context_manager import FlyteContextManager, ExecutionState
1014
from flytekit.core.dynamic_workflow_task import dynamic
1115
from flytekit.core.type_engine import SimpleTransformer
1216
from flytekit.core.type_engine import TypeEngine, AsyncTypeTransformer, TypeTransformerFailedError
@@ -339,3 +343,77 @@ def t0(ii: Annotated[MyStr, tf]) -> None:
339343
print(dynamic_job_spec)
340344

341345
del TypeEngine._REGISTRY[MyStr]
346+
347+
348+
# stand-in for prophet.Prophet models
349+
class Prophet:
350+
def __init__(self, **kwargs):
351+
self.kwargs = kwargs
352+
353+
354+
def test_annotated_dynamic():
355+
# Define a dummy Pydantic dataclass
356+
class Config(BaseModel):
357+
forecast_horizon: int = 30
358+
changepoint_prior_scale: float = 0.05
359+
seasonality_mode: str = "multiplicative"
360+
361+
@task
362+
def get_config() -> Config:
363+
return Config()
364+
365+
def dt_function():
366+
# Create dummy Prophet models
367+
best_models = {
368+
"model_1": Prophet(changepoint_prior_scale=0.05),
369+
"model_2": Prophet(changepoint_prior_scale=0.1)
370+
}
371+
372+
# Instantiate config
373+
cfg = get_config()
374+
375+
# Create final model dictionary
376+
final_model = {
377+
"model": best_models,
378+
"config": cfg
379+
}
380+
381+
return final_model
382+
383+
a = Artifact(name="my_model")
384+
385+
@dynamic
386+
def dt1() -> Annotated[dict[str, Union[Config, dict[str, Prophet]]], a]:
387+
return dt_function()
388+
389+
@dynamic
390+
def dt_plain() -> dict[str, Union[Config, dict[str, Prophet]]]:
391+
return dt_function()
392+
393+
ss = SerializationSettings(
394+
project="test_proj",
395+
domain="test_domain",
396+
version="abc",
397+
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
398+
env={},
399+
)
400+
401+
with FlyteContextManager.with_context(
402+
FlyteContextManager.current_context().with_serialization_settings(ss)
403+
) as ctx:
404+
new_exc_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
405+
with FlyteContextManager.with_context(ctx.with_execution_state(new_exc_state)):
406+
dynamic_job_spec_dt1 = dt1.compile_into_workflow(ctx, dt1._task_function, )
407+
408+
with FlyteContextManager.with_context(
409+
FlyteContextManager.current_context().with_serialization_settings(ss)
410+
) as ctx:
411+
new_exc_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
412+
with FlyteContextManager.with_context(ctx.with_execution_state(new_exc_state)):
413+
dynamic_job_spec_plain = dt_plain.compile_into_workflow(ctx, dt_plain._task_function, )
414+
415+
assert len(dynamic_job_spec_dt1.nodes) == 1
416+
assert len(dynamic_job_spec_plain.nodes) == 1
417+
418+
assert len(dynamic_job_spec_dt1.outputs[0].binding.map.bindings["model"].map.bindings) == 2
419+
assert len(dynamic_job_spec_plain.outputs[0].binding.map.bindings["model"].map.bindings) == 2

tests/flytekit/unit/core/test_generice_idl_type_engine.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2858,23 +2858,6 @@ def test_get_underlying_type(t, expected):
28582858
assert get_underlying_type(t) == expected
28592859

28602860

2861-
@pytest.mark.parametrize(
2862-
"t,expected",
2863-
[
2864-
(None, (None, None)),
2865-
(typing.Dict, ()),
2866-
(typing.Dict[str, str], (str, str)),
2867-
(
2868-
Annotated[typing.Dict[str, str], kwtypes(allow_pickle=True)],
2869-
(typing.Dict[str, str], kwtypes(allow_pickle=True)),
2870-
),
2871-
(typing.Dict[Annotated[str, "a-tag"], int], (Annotated[str, "a-tag"], int)),
2872-
],
2873-
)
2874-
def test_dict_get(t, expected):
2875-
assert DictTransformer.extract_types_or_metadata(t) == expected
2876-
2877-
28782861
def test_DataclassTransformer_get_literal_type():
28792862
@dataclass
28802863
class MyDataClassMashumaro(DataClassJsonMixin):

tests/flytekit/unit/core/test_type_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,13 +2871,13 @@ def test_get_underlying_type(t, expected):
28712871
(typing.Dict[str, str], (str, str)),
28722872
(
28732873
Annotated[typing.Dict[str, str], kwtypes(allow_pickle=True)],
2874-
(typing.Dict[str, str], kwtypes(allow_pickle=True)),
2874+
(str, str),
28752875
),
28762876
(typing.Dict[Annotated[str, "a-tag"], int], (Annotated[str, "a-tag"], int)),
28772877
],
28782878
)
28792879
def test_dict_get(t, expected):
2880-
assert DictTransformer.extract_types_or_metadata(t) == expected
2880+
assert DictTransformer.extract_types(t) == expected
28812881

28822882

28832883
def test_DataclassTransformer_get_literal_type():

0 commit comments

Comments
 (0)