Skip to content

Commit 3d543ed

Browse files
committed
fixes handling of union types
In the dataclasses conversion code type that is a member of a union was not properly checked for if it was a member and so there would always be an error. For instance `FlyteFile.path` is `Union[str,Pathlike]` and so `str != Union[str,Pathlike]`. This patch adds support for checking that a type is part of a union and a satisfactory type. Signed-off-by: Samuel Lotz <[email protected]>
1 parent 1a25939 commit 3d543ed

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

flytekit/core/type_engine.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import textwrap
1515
import threading
16+
import types
1617
import typing
1718
from abc import ABC, abstractmethod
1819
from collections import OrderedDict
@@ -565,7 +566,13 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
565566
original_type = type(v)
566567
if UnionTransformer.is_optional_type(expected_type):
567568
expected_type = UnionTransformer.get_sub_type_in_optional(expected_type)
568-
if original_type != expected_type:
569+
570+
if UnionTransformer.is_union(expected_type) and UnionTransformer.in_union(
571+
original_type, expected_type
572+
):
573+
pass
574+
575+
elif original_type != expected_type:
569576
raise TypeTransformerFailedError(
570577
f"Type of Val '{original_type}' is not an instance of {expected_type}"
571578
)
@@ -1836,6 +1843,14 @@ class UnionTransformer(AsyncTypeTransformer[T]):
18361843
def __init__(self):
18371844
super().__init__("Typed Union", typing.Union)
18381845

1846+
@staticmethod
1847+
def is_union(t: Type[Any] | types.UnionType) -> bool:
1848+
return _is_union_type(t)
1849+
1850+
@staticmethod
1851+
def in_union(t: Type[Any], union: types.UnionType) -> bool:
1852+
return t in typing.get_args(union)
1853+
18391854
@staticmethod
18401855
def is_optional_type(t: Type) -> bool:
18411856
return _is_union_type(t) and type(None) in get_args(t)

0 commit comments

Comments
 (0)