Skip to content

Commit 1858668

Browse files
authored
Merge pull request #696 from tclose/subclass-permissive-typing
Permit superclass to subclass lazy typing
2 parents 1720ba6 + 27e7fb8 commit 1858668

File tree

6 files changed

+316
-62
lines changed

6 files changed

+316
-62
lines changed

pydra/engine/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ def make_klass(spec):
263263
**kwargs,
264264
)
265265
checker_label = f"'{name}' field of {spec.name}"
266-
type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
266+
type_checker = TypeParser[newfield.type](
267+
newfield.type, label=checker_label, superclass_auto_cast=True
268+
)
267269
if newfield.type in (MultiInputObj, MultiInputFile):
268270
converter = attr.converters.pipe(ensure_list, type_checker)
269271
elif newfield.type in (MultiOutputObj, MultiOutputFile):

pydra/engine/specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
445445
),
446446
):
447447
raise TypeError(
448-
f"Support for {fld.type} type, required for {fld.name} in {self}, "
448+
f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
449449
"has not been implemented in collect_additional_output"
450450
)
451451
# assuming that field should have either default or metadata, but not both

pydra/engine/tests/test_node_task.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,7 @@ def test_task_init_3a(
133133

134134

135135
def test_task_init_4():
136-
"""task with interface and inputs. splitter set using split method"""
137-
nn = fun_addtwo(name="NA")
138-
nn.split(splitter="a", a=[3, 5])
139-
assert np.allclose(nn.inputs.a, [3, 5])
140-
141-
assert nn.state.splitter == "NA.a"
142-
assert nn.state.splitter_rpn == ["NA.a"]
143-
144-
nn.state.prepare_states(nn.inputs)
145-
assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
146-
assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}]
147-
148-
149-
def test_task_init_4a():
150-
"""task with a splitter and inputs set in the split method"""
136+
"""task with interface splitter and inputs set in the split method"""
151137
nn = fun_addtwo(name="NA")
152138
nn.split(splitter="a", a=[3, 5])
153139
assert np.allclose(nn.inputs.a, [3, 5])

pydra/utils/tests/test_typing.py

Lines changed: 170 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import itertools
3+
import sys
34
import typing as ty
45
from pathlib import Path
56
import tempfile
@@ -8,13 +9,16 @@
89
from ...engine.specs import File, LazyOutField
910
from ..typing import TypeParser
1011
from pydra import Workflow
11-
from fileformats.application import Json
12+
from fileformats.application import Json, Yaml, Xml
1213
from .utils import (
1314
generic_func_task,
1415
GenericShellTask,
1516
specific_func_task,
1617
SpecificShellTask,
18+
other_specific_func_task,
19+
OtherSpecificShellTask,
1720
MyFormatX,
21+
MyOtherFormatX,
1822
MyHeader,
1923
)
2024

@@ -152,8 +156,12 @@ def test_type_check_nested6():
152156

153157

154158
def test_type_check_nested7():
159+
TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int]))
160+
161+
162+
def test_type_check_nested7a():
155163
with pytest.raises(TypeError, match="Wrong number of type arguments"):
156-
TypeParser(ty.Tuple[float, float, float])(lz(ty.List[int]))
164+
TypeParser(ty.Tuple[float, float, float])(lz(ty.Tuple[int]))
157165

158166

159167
def test_type_check_nested8():
@@ -164,6 +172,18 @@ def test_type_check_nested8():
164172
)(lz(ty.List[float]))
165173

166174

175+
def test_type_check_permit_superclass():
176+
# Typical case as Json is subclass of File
177+
TypeParser(ty.List[File])(lz(ty.List[Json]))
178+
# Permissive super class, as File is superclass of Json
179+
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
180+
with pytest.raises(TypeError, match="Cannot coerce"):
181+
TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File]))
182+
# Fails because Yaml is neither sub or super class of Json
183+
with pytest.raises(TypeError, match="Cannot coerce"):
184+
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml]))
185+
186+
167187
def test_type_check_fail1():
168188
with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"):
169189
TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float]))
@@ -490,14 +510,29 @@ def test_matches_type_tuple():
490510
assert not TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, int])
491511

492512

493-
def test_matches_type_tuple_ellipsis():
513+
def test_matches_type_tuple_ellipsis1():
494514
assert TypeParser.matches_type(ty.Tuple[int], ty.Tuple[int, ...])
515+
516+
517+
def test_matches_type_tuple_ellipsis2():
495518
assert TypeParser.matches_type(ty.Tuple[int, int], ty.Tuple[int, ...])
519+
520+
521+
def test_matches_type_tuple_ellipsis3():
496522
assert not TypeParser.matches_type(ty.Tuple[int, float], ty.Tuple[int, ...])
497-
assert not TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int])
523+
524+
525+
def test_matches_type_tuple_ellipsis4():
526+
assert TypeParser.matches_type(ty.Tuple[int, ...], ty.Tuple[int])
527+
528+
529+
def test_matches_type_tuple_ellipsis5():
498530
assert TypeParser.matches_type(
499531
ty.Tuple[int], ty.List[int], coercible=[(tuple, list)]
500532
)
533+
534+
535+
def test_matches_type_tuple_ellipsis6():
501536
assert TypeParser.matches_type(
502537
ty.Tuple[int, ...], ty.List[int], coercible=[(tuple, list)]
503538
)
@@ -538,7 +573,17 @@ def specific_task(request):
538573
assert False
539574

540575

541-
def test_typing_cast(tmp_path, generic_task, specific_task):
576+
@pytest.fixture(params=["func", "shell"])
577+
def other_specific_task(request):
578+
if request.param == "func":
579+
return other_specific_func_task
580+
elif request.param == "shell":
581+
return OtherSpecificShellTask
582+
else:
583+
assert False
584+
585+
586+
def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task):
542587
"""Check the casting of lazy fields and whether specific file-sets can be recovered
543588
from generic `File` classes"""
544589

@@ -562,33 +607,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
562607
)
563608
)
564609

610+
wf.add(
611+
specific_task(
612+
in_file=wf.generic.lzout.out,
613+
name="specific2",
614+
)
615+
)
616+
617+
wf.set_output(
618+
[
619+
("out_file", wf.specific2.lzout.out),
620+
]
621+
)
622+
623+
in_file = MyFormatX.sample()
624+
625+
result = wf(in_file=in_file, plugin="serial")
626+
627+
out_file: MyFormatX = result.output.out_file
628+
assert type(out_file) is MyFormatX
629+
assert out_file.parent != in_file.parent
630+
assert type(out_file.header) is MyHeader
631+
assert out_file.header.parent != in_file.header.parent
632+
633+
634+
def test_typing_cast(tmp_path, specific_task, other_specific_task):
635+
"""Check the casting of lazy fields and whether specific file-sets can be recovered
636+
from generic `File` classes"""
637+
638+
wf = Workflow(
639+
name="test",
640+
input_spec={"in_file": MyFormatX},
641+
output_spec={"out_file": MyFormatX},
642+
)
643+
644+
wf.add(
645+
specific_task(
646+
in_file=wf.lzin.in_file,
647+
name="entry",
648+
)
649+
)
650+
651+
with pytest.raises(TypeError, match="Cannot coerce"):
652+
# No cast of generic task output to MyFormatX
653+
wf.add( # Generic task
654+
other_specific_task(
655+
in_file=wf.entry.lzout.out,
656+
name="inner",
657+
)
658+
)
659+
660+
wf.add( # Generic task
661+
other_specific_task(
662+
in_file=wf.entry.lzout.out.cast(MyOtherFormatX),
663+
name="inner",
664+
)
665+
)
666+
565667
with pytest.raises(TypeError, match="Cannot coerce"):
566668
# No cast of generic task output to MyFormatX
567669
wf.add(
568670
specific_task(
569-
in_file=wf.generic.lzout.out,
570-
name="specific2",
671+
in_file=wf.inner.lzout.out,
672+
name="exit",
571673
)
572674
)
573675

574676
wf.add(
575677
specific_task(
576-
in_file=wf.generic.lzout.out.cast(MyFormatX),
577-
name="specific2",
678+
in_file=wf.inner.lzout.out.cast(MyFormatX),
679+
name="exit",
578680
)
579681
)
580682

581683
wf.set_output(
582684
[
583-
("out_file", wf.specific2.lzout.out),
685+
("out_file", wf.exit.lzout.out),
584686
]
585687
)
586688

587-
my_fspath = tmp_path / "in_file.my"
588-
hdr_fspath = tmp_path / "in_file.hdr"
589-
my_fspath.write_text("my-format")
590-
hdr_fspath.write_text("my-header")
591-
in_file = MyFormatX([my_fspath, hdr_fspath])
689+
in_file = MyFormatX.sample()
592690

593691
result = wf(in_file=in_file, plugin="serial")
594692

@@ -611,6 +709,63 @@ def test_type_is_subclass3():
611709
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])
612710

613711

712+
def test_union_is_subclass1():
713+
assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])
714+
715+
716+
def test_union_is_subclass2():
717+
assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])
718+
719+
720+
def test_union_is_subclass3():
721+
assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml])
722+
723+
724+
def test_union_is_subclass4():
725+
assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json)
726+
727+
728+
def test_generic_is_subclass1():
729+
assert TypeParser.is_subclass(ty.List[int], list)
730+
731+
732+
def test_generic_is_subclass2():
733+
assert not TypeParser.is_subclass(list, ty.List[int])
734+
735+
736+
def test_generic_is_subclass3():
737+
assert not TypeParser.is_subclass(ty.List[float], ty.List[int])
738+
739+
740+
@pytest.mark.skipif(
741+
sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9"
742+
)
743+
def test_generic_is_subclass4():
744+
class MyTuple(tuple):
745+
pass
746+
747+
class A:
748+
pass
749+
750+
class B(A):
751+
pass
752+
753+
assert TypeParser.is_subclass(MyTuple[A], ty.Tuple[A])
754+
assert TypeParser.is_subclass(ty.Tuple[B], ty.Tuple[A])
755+
assert TypeParser.is_subclass(MyTuple[B], ty.Tuple[A])
756+
assert not TypeParser.is_subclass(ty.Tuple[A], ty.Tuple[B])
757+
assert not TypeParser.is_subclass(ty.Tuple[A], MyTuple[A])
758+
assert not TypeParser.is_subclass(MyTuple[A], ty.Tuple[B])
759+
assert TypeParser.is_subclass(MyTuple[A, int], ty.Tuple[A, int])
760+
assert TypeParser.is_subclass(ty.Tuple[B, int], ty.Tuple[A, int])
761+
assert TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A, int])
762+
assert TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[int, A])
763+
assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[int, A])
764+
assert not TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[A, int])
765+
assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A])
766+
assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int])
767+
768+
614769
def test_type_is_instance1():
615770
assert TypeParser.is_instance(File, ty.Type[File])
616771

pydra/utils/tests/utils.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from fileformats.generic import File
2-
from fileformats.core.mixin import WithSeparateHeader
2+
from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber
33
from pydra import mark
44
from pydra.engine.task import ShellCommandTask
55
from pydra.engine import specs
66

77

8-
class MyFormat(File):
8+
class MyFormat(WithMagicNumber, File):
99
ext = ".my"
10+
magic_number = b"MYFORMAT"
1011

1112

1213
class MyHeader(File):
@@ -17,6 +18,12 @@ class MyFormatX(WithSeparateHeader, MyFormat):
1718
header_type = MyHeader
1819

1920

21+
class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File):
22+
magic_number = b"MYFORMAT"
23+
ext = ".my"
24+
header_type = MyHeader
25+
26+
2027
@mark.task
2128
def generic_func_task(in_file: File) -> File:
2229
return in_file
@@ -118,3 +125,57 @@ class SpecificShellTask(ShellCommandTask):
118125
input_spec = specific_shell_input_spec
119126
output_spec = specific_shelloutput_spec
120127
executable = "echo"
128+
129+
130+
@mark.task
131+
def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX:
132+
return in_file
133+
134+
135+
other_specific_shell_input_fields = [
136+
(
137+
"in_file",
138+
MyOtherFormatX,
139+
{
140+
"help_string": "the input file",
141+
"argstr": "",
142+
"copyfile": "copy",
143+
"sep": " ",
144+
},
145+
),
146+
(
147+
"out",
148+
str,
149+
{
150+
"help_string": "output file name",
151+
"argstr": "",
152+
"position": -1,
153+
"output_file_template": "{in_file}", # Pass through un-altered
154+
},
155+
),
156+
]
157+
158+
other_specific_shell_input_spec = specs.SpecInfo(
159+
name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,)
160+
)
161+
162+
other_specific_shell_output_fields = [
163+
(
164+
"out",
165+
MyOtherFormatX,
166+
{
167+
"help_string": "output file",
168+
},
169+
),
170+
]
171+
other_specific_shelloutput_spec = specs.SpecInfo(
172+
name="Output",
173+
fields=other_specific_shell_output_fields,
174+
bases=(specs.ShellOutSpec,),
175+
)
176+
177+
178+
class OtherSpecificShellTask(ShellCommandTask):
179+
input_spec = other_specific_shell_input_spec
180+
output_spec = other_specific_shelloutput_spec
181+
executable = "echo"

0 commit comments

Comments
 (0)