diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 24a78f184b..3b6ac4c88e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -13,6 +13,7 @@ import sys import textwrap import threading +import types import typing from abc import ABC, abstractmethod from collections import OrderedDict @@ -531,10 +532,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): # However, FooSchema is created by flytekit and it's not equal to the user-defined dataclass (Foo). # Therefore, we should iterate all attributes in the dataclass and check the type of value in dataclass matches the expected_type. - expected_fields_dict = {} - - for f in dataclasses.fields(expected_type): - expected_fields_dict[f.name] = f.type + expected_fields_dict = typing.get_type_hints(expected_type) if isinstance(v, dict): original_dict = v @@ -567,9 +565,19 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): else: expected_type = expected_fields_dict[k] original_type = type(v) + is_optional = False if UnionTransformer.is_optional_type(expected_type): + is_optional = True expected_type = UnionTransformer.get_sub_type_in_optional(expected_type) - if original_type != expected_type: + + if is_optional and original_type is type(None): + pass + elif UnionTransformer.is_union(expected_type) and UnionTransformer.in_union( + original_type, expected_type + ): + pass + + elif original_type != expected_type: raise TypeTransformerFailedError( f"Type of Val '{original_type}' is not an instance of {expected_type}" ) @@ -1949,6 +1957,14 @@ class UnionTransformer(AsyncTypeTransformer[T]): def __init__(self): super().__init__("Typed Union", typing.Union) + @staticmethod + def is_union(t: Type[Any] | types.UnionType) -> bool: + return _is_union_type(t) + + @staticmethod + def in_union(t: Type[Any], union: types.UnionType) -> bool: + return t in typing.get_args(union) + @staticmethod def is_optional_type(t: Type) -> bool: return _is_union_type(t) and type(None) in get_args(t)