Skip to content
30 changes: 15 additions & 15 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,19 @@ def _type_essence(x: LiteralType) -> LiteralType:


def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool:
if upstream.union_type is not None:
# for each upstream variant, there must be a compatible type downstream
for v in upstream.union_type.variants:
if not _are_types_castable(v, downstream):
return False
return True

if downstream.union_type is not None:
# there must be a compatible downstream type
for v in downstream.union_type.variants:
if _are_types_castable(upstream, v):
return True

if upstream.collection_type is not None:
if downstream.collection_type is None:
return False
Expand Down Expand Up @@ -1814,19 +1827,6 @@ def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool:

return True

if upstream.union_type is not None:
# for each upstream variant, there must be a compatible type downstream
for v in upstream.union_type.variants:
if not _are_types_castable(v, downstream):
return False
return True

if downstream.union_type is not None:
# there must be a compatible downstream type
for v in downstream.union_type.variants:
if _are_types_castable(upstream, v):
return True

if upstream.enum_type is not None:
# enums are castable to string
if downstream.simple == SimpleType.STRING:
Expand Down Expand Up @@ -2113,7 +2113,7 @@ async def dict_to_generic_literal(
),
metadata={"format": "pickle"},
)
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\nError Message: {e}")

@staticmethod
async def dict_to_binary_literal(
Expand All @@ -2139,7 +2139,7 @@ async def dict_to_binary_literal(
),
metadata={"format": "pickle"},
)
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\nError Message: {e}")

@staticmethod
def is_pickle(python_type: Type[dict]) -> bool:
Expand Down
Loading