Skip to content

Commit f31f61a

Browse files
committed
debugged is_subclass so it works properly for union types
1 parent d1e008f commit f31f61a

File tree

2 files changed

+58
-29
lines changed

2 files changed

+58
-29
lines changed

pydra/utils/tests/test_typing.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,14 +611,34 @@ def test_type_is_subclass3():
611611
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])
612612

613613

614-
def test_type_is_subclass4():
614+
def test_union_is_subclass1():
615615
assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])
616616

617617

618-
def test_type_is_subclass5():
618+
def test_union_is_subclass2():
619619
assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])
620620

621621

622+
def test_union_is_subclass3():
623+
assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml])
624+
625+
626+
def test_union_is_subclass4():
627+
assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json)
628+
629+
630+
def test_generic_is_subclass1():
631+
assert TypeParser.is_subclass(ty.List[int], list)
632+
633+
634+
def test_generic_is_subclass2():
635+
assert not TypeParser.is_subclass(list, ty.List[int])
636+
637+
638+
def test_generic_is_subclass3():
639+
assert not TypeParser.is_subclass(ty.List[float], ty.List[int])
640+
641+
622642
def test_type_is_instance1():
623643
assert TypeParser.is_instance(File, ty.Type[File])
624644

pydra/utils/typing.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def matches_type(
562562
def is_instance(
563563
cls,
564564
obj: object,
565-
candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
565+
candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]],
566566
) -> bool:
567567
"""Checks whether the object is an instance of cls or that cls is typing.Any,
568568
extending the built-in isinstance to check nested type args
@@ -574,7 +574,7 @@ def is_instance(
574574
candidates : type or ty.Iterable[type]
575575
the candidate types to check the object against
576576
"""
577-
if not isinstance(candidates, (tuple, list)):
577+
if not isinstance(candidates, ty.Sequence):
578578
candidates = [candidates]
579579
for candidate in candidates:
580580
if candidate is ty.Any:
@@ -600,7 +600,7 @@ def is_instance(
600600
def is_subclass(
601601
cls,
602602
klass: ty.Type[ty.Any],
603-
candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
603+
candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]],
604604
any_ok: bool = False,
605605
) -> bool:
606606
"""Checks whether the class a is either the same as b, a subclass of b or b is
@@ -617,16 +617,23 @@ def is_subclass(
617617
"""
618618
if not isinstance(candidates, ty.Sequence):
619619
candidates = [candidates]
620+
if ty.Any in candidates:
621+
return True
622+
if klass is ty.Any:
623+
return any_ok
624+
625+
origin = get_origin(klass)
626+
args = get_args(klass)
620627

621628
for candidate in candidates:
629+
candidate_origin = get_origin(candidate)
630+
candidate_args = get_args(candidate)
622631
# Handle ty.Type[*] types in klass and candidates
623-
if ty.get_origin(klass) is type and (
624-
candidate is type or ty.get_origin(candidate) is type
625-
):
632+
if origin is type and (candidate is type or candidate_origin is type):
626633
if candidate is type:
627634
return True
628-
return cls.is_subclass(ty.get_args(klass)[0], ty.get_args(candidate)[0])
629-
elif ty.get_origin(klass) is type or ty.get_origin(candidate) is type:
635+
return cls.is_subclass(args[0], candidate_args[0])
636+
elif origin is type or candidate_origin is type:
630637
return False
631638
if NO_GENERIC_ISSUBCLASS:
632639
if klass is type and candidate is not type:
@@ -636,27 +643,29 @@ def is_subclass(
636643
):
637644
return True
638645
else:
639-
if klass is ty.Any:
640-
if ty.Any in candidates: # type: ignore
641-
return True
642-
else:
643-
return any_ok
644-
origin = get_origin(klass)
645646
if origin is ty.Union:
646-
args = get_args(klass)
647-
if get_origin(candidate) is ty.Union:
648-
candidate_args = get_args(candidate)
649-
else:
650-
candidate_args = [candidate]
651-
return all(
652-
any(cls.is_subclass(a, c) for c in candidate_args) for a in args
647+
union_args = (
648+
candidate_args if candidate_origin is ty.Union else (candidate,)
653649
)
654-
if origin is not None:
655-
klass = origin
656-
if klass is candidate or candidate is ty.Any:
657-
return True
658-
if issubclass(klass, candidate):
659-
return True
650+
matches = all(
651+
any(cls.is_subclass(a, c) for c in union_args) for a in args
652+
)
653+
if matches:
654+
return True
655+
else:
656+
if candidate_args and candidate_origin is not ty.Union:
657+
if (
658+
origin
659+
and issubclass(origin, candidate_origin) # type: ignore[arg-type]
660+
and len(args) == len(candidate_args)
661+
and all(
662+
issubclass(a, c) for a, c in zip(args, candidate_args)
663+
)
664+
):
665+
return True
666+
else:
667+
if issubclass(origin if origin else klass, candidate):
668+
return True
660669
return False
661670

662671
@classmethod

0 commit comments

Comments
 (0)