Skip to content

Commit 0eed44a

Browse files
tcloseeffigies
authored andcommitted
added unittests for multi_input_obj coercion
1 parent 5a7955e commit 0eed44a

File tree

2 files changed

+103
-19
lines changed

2 files changed

+103
-19
lines changed

pydra/utils/tests/test_typing.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tempfile
88
import pytest
99
from pydra import mark
10-
from ...engine.specs import File, LazyOutField
10+
from ...engine.specs import File, LazyOutField, MultiInputObj
1111
from ..typing import TypeParser
1212
from pydra import Workflow
1313
from fileformats.application import Json, Yaml, Xml
@@ -249,7 +249,7 @@ def test_type_check_fail3():
249249
def test_type_check_fail4():
250250
with pytest.raises(TypeError) as exc_info:
251251
TypeParser(ty.Sequence)(lz(ty.Dict[str, int]))
252-
assert exc_info_matches(exc_info, "Cannot coerce <class 'dict'> into")
252+
assert exc_info_matches(exc_info, "Cannot coerce .*(d|D)ict.* into")
253253

254254

255255
def test_type_check_fail5():
@@ -1043,3 +1043,63 @@ def test_type_is_instance11():
10431043
@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10")
10441044
def test_type_is_instance11a():
10451045
assert not TypeParser.is_instance(None, int | str)
1046+
1047+
1048+
def test_multi_input_obj_coerce1():
1049+
assert TypeParser(MultiInputObj[str])("a") == ["a"]
1050+
1051+
1052+
def test_multi_input_obj_coerce2():
1053+
assert TypeParser(MultiInputObj[str])(["a"]) == ["a"]
1054+
1055+
1056+
def test_multi_input_obj_coerce3():
1057+
assert TypeParser(MultiInputObj[ty.List[str]])(["a"]) == [["a"]]
1058+
1059+
1060+
def test_multi_input_obj_coerce3a():
1061+
assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(["a"]) == [["a"]]
1062+
1063+
1064+
def test_multi_input_obj_coerce3b():
1065+
assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([["a"]]) == [["a"]]
1066+
1067+
1068+
def test_multi_input_obj_coerce4():
1069+
assert TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([1]) == [1]
1070+
1071+
1072+
def test_multi_input_obj_coerce4a():
1073+
with pytest.raises(TypeError):
1074+
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])([[1]])
1075+
1076+
1077+
def test_multi_input_obj_check_type1():
1078+
TypeParser(MultiInputObj[str])(lz(str))
1079+
1080+
1081+
def test_multi_input_obj_check_type2():
1082+
TypeParser(MultiInputObj[str])(lz(ty.List[str]))
1083+
1084+
1085+
def test_multi_input_obj_check_type3():
1086+
TypeParser(MultiInputObj[ty.List[str]])(lz(ty.List[str]))
1087+
1088+
1089+
def test_multi_input_obj_check_type3a():
1090+
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[str]))
1091+
1092+
1093+
def test_multi_input_obj_check_type3b():
1094+
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[ty.List[str]]))
1095+
1096+
1097+
def test_multi_input_obj_check_type4():
1098+
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(lz(ty.List[int]))
1099+
1100+
1101+
def test_multi_input_obj_check_type4a():
1102+
with pytest.raises(TypeError):
1103+
TypeParser(MultiInputObj[ty.Union[int, ty.List[str]]])(
1104+
lz(ty.List[ty.List[int]])
1105+
)

pydra/utils/typing.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
from pathlib import Path
44
import os
5+
from copy import copy
56
import sys
67
import types
78
import typing as ty
@@ -13,6 +14,7 @@
1314
MultiInputObj,
1415
MultiOutputObj,
1516
)
17+
from ..utils import add_exc_note
1618
from fileformats import field
1719

1820
try:
@@ -366,18 +368,26 @@ def coerce_obj(obj, type_):
366368
f"Cannot coerce {obj!r} into {type_}{msg}{self.label_str}"
367369
) from e
368370

369-
# Special handling for MultiInputObjects (which are annoying)
370-
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
371-
try:
372-
self.check_coercible(object_, self.pattern[1][0])
373-
except TypeError:
374-
pass
371+
try:
372+
return expand_and_coerce(object_, self.pattern)
373+
except TypeError as e:
374+
# Special handling for MultiInputObjects (which are annoying)
375+
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
376+
# Attempt to coerce the object into arg type of the MultiInputObj first,
377+
# and if that fails, try to coerce it into a list of the arg type
378+
inner_type_parser = copy(self)
379+
inner_type_parser.pattern = self.pattern[1][0]
380+
try:
381+
return [inner_type_parser.coerce(object_)]
382+
except TypeError:
383+
add_exc_note(
384+
e,
385+
"Also failed to coerce to the arg-type of the MultiInputObj "
386+
f"({self.pattern[1][0]})",
387+
)
388+
raise e
375389
else:
376-
obj = [object_]
377-
else:
378-
obj = object_
379-
380-
return expand_and_coerce(obj, self.pattern)
390+
raise e
381391

382392
def check_type(self, type_: ty.Type[ty.Any]):
383393
"""Checks the given type to see whether it matches or is a subtype of the
@@ -537,12 +547,26 @@ def check_sequence(tp_args, pattern_args):
537547
for arg in tp_args:
538548
expand_and_check(arg, pattern_args[0])
539549

540-
# Special handling for MultiInputObjects (which are annoying)
541-
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
542-
pattern = (ty.Union, [self.pattern[1][0], (ty.List, self.pattern[1])])
543-
else:
544-
pattern = self.pattern
545-
return expand_and_check(type_, pattern)
550+
try:
551+
return expand_and_check(type_, self.pattern)
552+
except TypeError as e:
553+
# Special handling for MultiInputObjects (which are annoying)
554+
if isinstance(self.pattern, tuple) and self.pattern[0] == MultiInputObj:
555+
# Attempt to coerce the object into arg type of the MultiInputObj first,
556+
# and if that fails, try to coerce it into a list of the arg type
557+
inner_type_parser = copy(self)
558+
inner_type_parser.pattern = self.pattern[1][0]
559+
try:
560+
inner_type_parser.check_type(type_)
561+
except TypeError:
562+
add_exc_note(
563+
e,
564+
"Also failed to coerce to the arg-type of the MultiInputObj "
565+
f"({self.pattern[1][0]})",
566+
)
567+
raise e
568+
else:
569+
raise e
546570

547571
def check_coercible(self, source: ty.Any, target: ty.Union[type, ty.Any]):
548572
"""Checks whether the source object is coercible to the target type given the coercion

0 commit comments

Comments
 (0)