Skip to content

Commit d1e008f

Browse files
committed
added subclass tests for unions
1 parent 1720ba6 commit d1e008f

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

pydra/engine/specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
445445
),
446446
):
447447
raise TypeError(
448-
f"Support for {fld.type} type, required for {fld.name} in {self}, "
448+
f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
449449
"has not been implemented in collect_additional_output"
450450
)
451451
# assuming that field should have either default or metadata, but not both

pydra/utils/tests/test_typing.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ...engine.specs import File, LazyOutField
99
from ..typing import TypeParser
1010
from pydra import Workflow
11-
from fileformats.application import Json
11+
from fileformats.application import Json, Yaml, Xml
1212
from .utils import (
1313
generic_func_task,
1414
GenericShellTask,
@@ -611,6 +611,14 @@ def test_type_is_subclass3():
611611
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])
612612

613613

614+
def test_type_is_subclass4():
615+
assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])
616+
617+
618+
def test_type_is_subclass5():
619+
assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])
620+
621+
614622
def test_type_is_instance1():
615623
assert TypeParser.is_instance(File, ty.Type[File])
616624

pydra/utils/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def is_subclass(
649649
else:
650650
candidate_args = [candidate]
651651
return all(
652-
any(cls.is_subclass(a, c) for a in args) for c in candidate_args
652+
any(cls.is_subclass(a, c) for c in candidate_args) for a in args
653653
)
654654
if origin is not None:
655655
klass = origin

0 commit comments

Comments
 (0)