Skip to content

Commit 4028ecc

Browse files
committed
added test for is_container
1 parent ad11ea5 commit 4028ecc

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

pydra/utils/tests/test_typing.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fileformats.generic import File
1212
from pydra.engine.lazy import LazyOutField
1313
from pydra.compose import workflow
14-
from pydra.utils.typing import TypeParser, MultiInputObj
14+
from pydra.utils.typing import TypeParser, MultiInputObj, is_container
1515
from fileformats.application import Json, Yaml, Xml
1616
from .utils import (
1717
GenericFuncTask,
@@ -866,6 +866,34 @@ def test_none_is_subclass2a():
866866
assert not TypeParser.is_subclass(None, int | float)
867867

868868

869+
@pytest.mark.parametrize(
870+
("type_",),
871+
[
872+
(str,),
873+
(ty.List[int],),
874+
(ty.Tuple[int, ...],),
875+
(ty.Dict[str, int],),
876+
(ty.Union[ty.List[int], ty.Tuple[int, ...]],),
877+
(ty.Union[ty.List[int], ty.Dict[str, int]],),
878+
(ty.Union[ty.List[int], ty.Tuple[int, ...], ty.Dict[str, int]],),
879+
],
880+
)
881+
def test_is_container(type_):
882+
assert is_container(type_)
883+
884+
885+
@pytest.mark.parametrize(
886+
("type_",),
887+
[
888+
(int,),
889+
(bool,),
890+
(ty.Union[bool, str],),
891+
],
892+
)
893+
def test_is_not_container(type_):
894+
assert not is_container(type_)
895+
896+
869897
@pytest.mark.skipif(
870898
sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9"
871899
)

pydra/utils/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,8 @@ def is_optional(type_: type) -> bool:
10731073
def is_container(type_: type) -> bool:
10741074
"""Check if the type is a container, i.e. a list, tuple, or MultiOutputObj"""
10751075
origin = ty.get_origin(type_)
1076+
if origin is ty.Union:
1077+
return all(is_container(a) for a in ty.get_args(type_))
10761078
tp = origin if origin else type_
10771079
return inspect.isclass(tp) and issubclass(tp, ty.Container)
10781080

0 commit comments

Comments
 (0)