Skip to content

Commit 0cafc6e

Browse files
committed
added tests for explict and auto-superclass casting
1 parent 0ccc223 commit 0cafc6e

File tree

4 files changed

+182
-22
lines changed

4 files changed

+182
-22
lines changed

pydra/engine/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def make_klass(spec):
264264
)
265265
checker_label = f"'{name}' field of {spec.name}"
266266
type_checker = TypeParser[newfield.type](
267-
newfield.type, label=checker_label, allow_lazy_super=True
267+
newfield.type, label=checker_label, superclass_auto_cast=True
268268
)
269269
if newfield.type in (MultiInputObj, MultiInputFile):
270270
converter = attr.converters.pipe(ensure_list, type_checker)

pydra/utils/tests/test_typing.py

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
GenericShellTask,
1515
specific_func_task,
1616
SpecificShellTask,
17+
other_specific_func_task,
18+
OtherSpecificShellTask,
1719
MyFormatX,
20+
MyOtherFormatX,
1821
MyHeader,
1922
)
2023

@@ -168,12 +171,12 @@ def test_type_check_permit_superclass():
168171
# Typical case as Json is subclass of File
169172
TypeParser(ty.List[File])(lz(ty.List[Json]))
170173
# Permissive super class, as File is superclass of Json
171-
TypeParser(ty.List[Json], allow_lazy_super=True)(lz(ty.List[File]))
174+
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
172175
with pytest.raises(TypeError, match="Cannot coerce"):
173-
TypeParser(ty.List[Json], allow_lazy_super=False)(lz(ty.List[File]))
176+
TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File]))
174177
# Fails because Yaml is neither sub or super class of Json
175178
with pytest.raises(TypeError, match="Cannot coerce"):
176-
TypeParser(ty.List[Json], allow_lazy_super=True)(lz(ty.List[Yaml]))
179+
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml]))
177180

178181

179182
def test_type_check_fail1():
@@ -550,7 +553,17 @@ def specific_task(request):
550553
assert False
551554

552555

553-
def test_typing_cast(tmp_path, generic_task, specific_task):
556+
@pytest.fixture(params=["func", "shell"])
557+
def other_specific_task(request):
558+
if request.param == "func":
559+
return other_specific_func_task
560+
elif request.param == "shell":
561+
return OtherSpecificShellTask
562+
else:
563+
assert False
564+
565+
566+
def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task):
554567
"""Check the casting of lazy fields and whether specific file-sets can be recovered
555568
from generic `File` classes"""
556569

@@ -574,33 +587,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
574587
)
575588
)
576589

590+
wf.add(
591+
specific_task(
592+
in_file=wf.generic.lzout.out,
593+
name="specific2",
594+
)
595+
)
596+
597+
wf.set_output(
598+
[
599+
("out_file", wf.specific2.lzout.out),
600+
]
601+
)
602+
603+
in_file = MyFormatX.sample()
604+
605+
result = wf(in_file=in_file, plugin="serial")
606+
607+
out_file: MyFormatX = result.output.out_file
608+
assert type(out_file) is MyFormatX
609+
assert out_file.parent != in_file.parent
610+
assert type(out_file.header) is MyHeader
611+
assert out_file.header.parent != in_file.header.parent
612+
613+
614+
def test_typing_cast(tmp_path, specific_task, other_specific_task):
615+
"""Check the casting of lazy fields and whether specific file-sets can be recovered
616+
from generic `File` classes"""
617+
618+
wf = Workflow(
619+
name="test",
620+
input_spec={"in_file": MyFormatX},
621+
output_spec={"out_file": MyFormatX},
622+
)
623+
624+
wf.add(
625+
specific_task(
626+
in_file=wf.lzin.in_file,
627+
name="entry",
628+
)
629+
)
630+
631+
with pytest.raises(TypeError, match="Cannot coerce"):
632+
# No cast of generic task output to MyFormatX
633+
wf.add( # Generic task
634+
other_specific_task(
635+
in_file=wf.entry.lzout.out,
636+
name="inner",
637+
)
638+
)
639+
640+
wf.add( # Generic task
641+
other_specific_task(
642+
in_file=wf.entry.lzout.out.cast(MyOtherFormatX),
643+
name="inner",
644+
)
645+
)
646+
577647
with pytest.raises(TypeError, match="Cannot coerce"):
578648
# No cast of generic task output to MyFormatX
579649
wf.add(
580650
specific_task(
581-
in_file=wf.generic.lzout.out,
582-
name="specific2",
651+
in_file=wf.inner.lzout.out,
652+
name="exit",
583653
)
584654
)
585655

586656
wf.add(
587657
specific_task(
588-
in_file=wf.generic.lzout.out.cast(MyFormatX),
589-
name="specific2",
658+
in_file=wf.inner.lzout.out.cast(MyFormatX),
659+
name="exit",
590660
)
591661
)
592662

593663
wf.set_output(
594664
[
595-
("out_file", wf.specific2.lzout.out),
665+
("out_file", wf.exit.lzout.out),
596666
]
597667
)
598668

599-
my_fspath = tmp_path / "in_file.my"
600-
hdr_fspath = tmp_path / "in_file.hdr"
601-
my_fspath.write_text("my-format")
602-
hdr_fspath.write_text("my-header")
603-
in_file = MyFormatX([my_fspath, hdr_fspath])
669+
in_file = MyFormatX.sample()
604670

605671
result = wf(in_file=in_file, plugin="serial")
606672

pydra/utils/tests/utils.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from pathlib import Path
2+
import typing as ty
13
from fileformats.generic import File
2-
from fileformats.core.mixin import WithSeparateHeader
4+
from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber
35
from pydra import mark
46
from pydra.engine.task import ShellCommandTask
57
from pydra.engine import specs
68

79

8-
class MyFormat(File):
10+
class MyFormat(WithMagicNumber, File):
911
ext = ".my"
12+
magic_number = b"MYFORMAT"
1013

1114

1215
class MyHeader(File):
@@ -17,6 +20,34 @@ class MyFormatX(WithSeparateHeader, MyFormat):
1720
header_type = MyHeader
1821

1922

23+
class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File):
24+
magic_number = b"MYFORMAT"
25+
ext = ".my"
26+
header_type = MyHeader
27+
28+
29+
@File.generate_sample_data.register
30+
def my_format_x_generate_sample_data(
31+
my_format_x: MyFormatX, dest_dir: Path
32+
) -> ty.List[Path]:
33+
fspath = dest_dir / "file.my"
34+
with open(fspath, "wb") as f:
35+
f.write(b"MYFORMAT\nsome data goes here")
36+
header_fspath = dest_dir / "file.hdr"
37+
header_fspath.write_text("a: 1\nb: 2\nc: 3\n")
38+
return [fspath, header_fspath]
39+
40+
41+
@File.generate_sample_data.register
42+
def my_other_format_generate_sample_data(
43+
my_other_format: MyOtherFormatX, dest_dir: Path
44+
) -> ty.List[Path]:
45+
fspath = dest_dir / "file.my"
46+
with open(fspath, "wb") as f:
47+
f.write(b"MYFORMAT\nsome data goes here")
48+
return [fspath]
49+
50+
2051
@mark.task
2152
def generic_func_task(in_file: File) -> File:
2253
return in_file
@@ -118,3 +149,57 @@ class SpecificShellTask(ShellCommandTask):
118149
input_spec = specific_shell_input_spec
119150
output_spec = specific_shelloutput_spec
120151
executable = "echo"
152+
153+
154+
@mark.task
155+
def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX:
156+
return in_file
157+
158+
159+
other_specific_shell_input_fields = [
160+
(
161+
"in_file",
162+
MyOtherFormatX,
163+
{
164+
"help_string": "the input file",
165+
"argstr": "",
166+
"copyfile": "copy",
167+
"sep": " ",
168+
},
169+
),
170+
(
171+
"out",
172+
str,
173+
{
174+
"help_string": "output file name",
175+
"argstr": "",
176+
"position": -1,
177+
"output_file_template": "{in_file}", # Pass through un-altered
178+
},
179+
),
180+
]
181+
182+
other_specific_shell_input_spec = specs.SpecInfo(
183+
name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,)
184+
)
185+
186+
other_specific_shell_output_fields = [
187+
(
188+
"out",
189+
MyOtherFormatX,
190+
{
191+
"help_string": "output file",
192+
},
193+
),
194+
]
195+
other_specific_shelloutput_spec = specs.SpecInfo(
196+
name="Output",
197+
fields=other_specific_shell_output_fields,
198+
bases=(specs.ShellOutSpec,),
199+
)
200+
201+
202+
class OtherSpecificShellTask(ShellCommandTask):
203+
input_spec = other_specific_shell_input_spec
204+
output_spec = other_specific_shelloutput_spec
205+
executable = "echo"

pydra/utils/typing.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class TypeParser(ty.Generic[T]):
5858
the tree of more complex nested container types. Overrides 'coercible' to enable
5959
you to carve out exceptions, such as TypeParser(list, coercible=[(ty.Iterable, list)],
6060
not_coercible=[(str, list)])
61-
allow_lazy_super : bool
61+
superclass_auto_cast : bool
6262
Allow lazy fields to pass the type check if their types are superclasses of the
6363
specified pattern (instead of matching or being subclasses of the pattern)
6464
label : str
@@ -69,7 +69,7 @@ class TypeParser(ty.Generic[T]):
6969
tp: ty.Type[T]
7070
coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]]
7171
not_coercible: ty.List[ty.Tuple[TypeOrAny, TypeOrAny]]
72-
allow_lazy_super: bool
72+
superclass_auto_cast: bool
7373
label: str
7474

7575
COERCIBLE_DEFAULT: ty.Tuple[ty.Tuple[type, type], ...] = (
@@ -113,7 +113,7 @@ def __init__(
113113
not_coercible: ty.Optional[
114114
ty.Iterable[ty.Tuple[TypeOrAny, TypeOrAny]]
115115
] = NOT_COERCIBLE_DEFAULT,
116-
allow_lazy_super: bool = False,
116+
superclass_auto_cast: bool = False,
117117
label: str = "",
118118
):
119119
def expand_pattern(t):
@@ -142,7 +142,7 @@ def expand_pattern(t):
142142
)
143143
self.not_coercible = list(not_coercible) if not_coercible is not None else []
144144
self.pattern = expand_pattern(tp)
145-
self.allow_lazy_super = allow_lazy_super
145+
self.superclass_auto_cast = superclass_auto_cast
146146

147147
def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]:
148148
"""Attempts to coerce the object to the specified type, unless the value is
@@ -172,7 +172,7 @@ def __call__(self, obj: ty.Any) -> ty.Union[T, LazyField[T]]:
172172
try:
173173
self.check_type(obj.type)
174174
except TypeError as e:
175-
if self.allow_lazy_super:
175+
if self.superclass_auto_cast:
176176
try:
177177
# Check whether the type of the lazy field isn't a superclass of
178178
# the type to check against, and if so, allow it due to permissive
@@ -492,8 +492,17 @@ def check_coercible(
492492
explicit inclusions and exclusions set in the `coercible` and `not_coercible`
493493
member attrs
494494
"""
495+
# Short-circuit the basic cases where the source and target are the same
495496
if source is target:
496497
return
498+
if self.superclass_auto_cast and self.is_subclass(target, type(source)):
499+
logger.info(
500+
"Attempting to coerce %s into %s due to super-to-sub class coercion "
501+
"being permitted",
502+
source,
503+
target,
504+
)
505+
return
497506
source_origin = get_origin(source)
498507
if source_origin is not None:
499508
source = source_origin

0 commit comments

Comments
 (0)