Skip to content

Commit 29e3af4

Browse files
committed
fixed up handling of type types (e.g. ty.Type[*])
1 parent 32e10d3 commit 29e3af4

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

pydra/utils/tests/test_typing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,27 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
597597
assert out_file.parent != in_file.parent
598598
assert type(out_file.header) is MyHeader
599599
assert out_file.header.parent != in_file.header.parent
600+
601+
602+
def test_type_is_subclass1():
603+
assert TypeParser.is_subclass(ty.Type[File], type)
604+
605+
606+
def test_type_is_subclass2():
607+
assert not TypeParser.is_subclass(ty.Type[File], ty.Type[Json])
608+
609+
610+
def test_type_is_subclass3():
611+
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])
612+
613+
614+
def test_type_is_instance1():
615+
assert TypeParser.is_instance(File, ty.Type[File])
616+
617+
618+
def test_type_is_instance2():
619+
assert not TypeParser.is_instance(File, ty.Type[Json])
620+
621+
622+
def test_type_is_instance3():
623+
assert TypeParser.is_instance(Json, ty.Type[File])

pydra/utils/typing.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -547,9 +547,11 @@ def matches_type(
547547
return False
548548
return True
549549

550-
@staticmethod
550+
@classmethod
551551
def is_instance(
552-
obj: object, candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]]
552+
cls,
553+
obj: object,
554+
candidates: ty.Union[ty.Type[ty.Any], ty.Iterable[ty.Type[ty.Any]]],
553555
) -> bool:
554556
"""Checks whether the object is an instance of cls or that cls is typing.Any,
555557
extending the built-in isinstance to check nested type args
@@ -566,9 +568,14 @@ def is_instance(
566568
for candidate in candidates:
567569
if candidate is ty.Any:
568570
return True
571+
# Handle ty.Type[*] candidates
572+
if ty.get_origin(candidate) is type:
573+
return inspect.isclass(obj) and cls.is_subclass(
574+
obj, ty.get_args(candidate)[0]
575+
)
569576
if NO_GENERIC_ISSUBCLASS:
570-
if candidate is type and inspect.isclass(obj):
571-
return True
577+
if inspect.isclass(obj):
578+
return candidate is type
572579
if issubtype(type(obj), candidate) or (
573580
type(obj) is dict and candidate is ty.Mapping
574581
):
@@ -597,10 +604,19 @@ def is_subclass(
597604
any_ok : bool
598605
whether klass=typing.Any should return True or False
599606
"""
600-
if not isinstance(candidates, ty.Iterable):
607+
if not isinstance(candidates, ty.Sequence):
601608
candidates = [candidates]
602609

603610
for candidate in candidates:
611+
# Handle ty.Type[*] types in klass and candidates
612+
if ty.get_origin(klass) is type and (
613+
candidate is type or ty.get_origin(candidate) is type
614+
):
615+
if candidate is type:
616+
return True
617+
return cls.is_subclass(ty.get_args(klass)[0], ty.get_args(candidate)[0])
618+
elif ty.get_origin(klass) is type or ty.get_origin(candidate) is type:
619+
return False
604620
if NO_GENERIC_ISSUBCLASS:
605621
if klass is type and candidate is not type:
606622
return False

0 commit comments

Comments
 (0)