Skip to content

Commit 31aea01

Browse files
authored
Merge pull request #687 from tclose/typing-bugfixes
Typing bugfixes
2 parents 428cf04 + 103cefc commit 31aea01

File tree

6 files changed

+230
-68
lines changed

6 files changed

+230
-68
lines changed

pydra/engine/helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ def make_klass(spec):
261261
type=tp,
262262
**kwargs,
263263
)
264-
type_checker = TypeParser[newfield.type](newfield.type)
264+
checker_label = f"'{name}' field of {spec.name}"
265+
type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
265266
if newfield.type in (MultiInputObj, MultiInputFile):
266267
converter = attr.converters.pipe(ensure_list, type_checker)
267268
elif newfield.type in (MultiOutputObj, MultiOutputFile):
@@ -652,7 +653,11 @@ def argstr_formatting(argstr, inputs, value_updates=None):
652653
for fld in inp_fields:
653654
fld_name = fld[1:-1] # extracting the name form {field_name}
654655
fld_value = inputs_dict[fld_name]
655-
if fld_value is attr.NOTHING:
656+
fld_attr = getattr(attrs.fields(type(inputs)), fld_name)
657+
if fld_value is attr.NOTHING or (
658+
fld_value is False
659+
and TypeParser.matches_type(fld_attr.type, ty.Union[Path, bool])
660+
):
656661
# if value is NOTHING, nothing should be added to the command
657662
val_dict[fld_name] = ""
658663
else:

pydra/engine/helpers_file.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def template_update(inputs, output_dir, state_ind=None, map_copyfiles=None):
120120
field
121121
for field in attr_fields(inputs)
122122
if field.metadata.get("output_file_template")
123+
and getattr(inputs, field.name) is not False
123124
and all(
124125
getattr(inputs, required_field) is not attr.NOTHING
125126
for required_field in field.metadata.get("requires", ())
@@ -150,25 +151,19 @@ def template_update_single(
150151
# if input_dict_st with state specific value is not available,
151152
# the dictionary will be created from inputs object
152153
from ..utils.typing import TypeParser # noqa
153-
from pydra.engine.specs import LazyField
154-
155-
VALID_TYPES = (str, ty.Union[str, bool], Path, ty.Union[Path, bool], LazyField)
154+
from pydra.engine.specs import LazyField, OUTPUT_TEMPLATE_TYPES
156155

157156
if inputs_dict_st is None:
158157
inputs_dict_st = attr.asdict(inputs, recurse=False)
159158

160159
if spec_type == "input":
161160
inp_val_set = inputs_dict_st[field.name]
162-
if inp_val_set is not attr.NOTHING and not TypeParser.is_instance(
163-
inp_val_set, VALID_TYPES
164-
):
165-
raise TypeError(
166-
f"'{field.name}' field has to be a Path instance or a bool, but {inp_val_set} set"
167-
)
168161
if isinstance(inp_val_set, bool) and field.type in (Path, str):
169162
raise TypeError(
170163
f"type of '{field.name}' is Path, consider using Union[Path, bool]"
171164
)
165+
if inp_val_set is not attr.NOTHING and not isinstance(inp_val_set, LazyField):
166+
inp_val_set = TypeParser(ty.Union[OUTPUT_TEMPLATE_TYPES])(inp_val_set)
172167
elif spec_type == "output":
173168
if not TypeParser.contains_type(FileSet, field.type):
174169
raise TypeError(
@@ -178,22 +173,23 @@ def template_update_single(
178173
else:
179174
raise TypeError(f"spec_type can be input or output, but {spec_type} provided")
180175
# for inputs that the value is set (so the template is ignored)
181-
if spec_type == "input" and isinstance(inputs_dict_st[field.name], (str, Path)):
182-
return inputs_dict_st[field.name]
183-
elif spec_type == "input" and inputs_dict_st[field.name] is False:
184-
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
185-
return attr.NOTHING
186-
else: # inputs_dict[field.name] is True or spec_type is output
187-
value = _template_formatting(field, inputs, inputs_dict_st)
188-
# changing path so it is in the output_dir
189-
if output_dir and value is not attr.NOTHING:
190-
# should be converted to str, it is also used for input fields that should be str
191-
if type(value) is list:
192-
return [str(output_dir / Path(val).name) for val in value]
193-
else:
194-
return str(output_dir / Path(value).name)
195-
else:
176+
if spec_type == "input":
177+
if isinstance(inp_val_set, (Path, list)):
178+
return inp_val_set
179+
if inp_val_set is False:
180+
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
196181
return attr.NOTHING
182+
# inputs_dict[field.name] is True or spec_type is output
183+
value = _template_formatting(field, inputs, inputs_dict_st)
184+
# changing path so it is in the output_dir
185+
if output_dir and value is not attr.NOTHING:
186+
# should be converted to str, it is also used for input fields that should be str
187+
if type(value) is list:
188+
return [str(output_dir / Path(val).name) for val in value]
189+
else:
190+
return str(output_dir / Path(value).name)
191+
else:
192+
return attr.NOTHING
197193

198194

199195
def _template_formatting(field, inputs, inputs_dict_st):
@@ -204,16 +200,27 @@ def _template_formatting(field, inputs, inputs_dict_st):
204200
Allowing for multiple input values used in the template as longs as
205201
there is no more than one file (i.e. File, PathLike or string with extensions)
206202
"""
207-
from .specs import MultiInputObj, MultiOutputFile
208-
209203
# if a template is a function it has to be run first with the inputs as the only arg
210204
template = field.metadata["output_file_template"]
211205
if callable(template):
212206
template = template(inputs)
213207

214208
# as default, we assume that keep_extension is True
215-
keep_extension = field.metadata.get("keep_extension", True)
209+
if isinstance(template, (tuple, list)):
210+
formatted = [
211+
_string_template_formatting(field, t, inputs, inputs_dict_st)
212+
for t in template
213+
]
214+
else:
215+
assert isinstance(template, str)
216+
formatted = _string_template_formatting(field, template, inputs, inputs_dict_st)
217+
return formatted
218+
216219

220+
def _string_template_formatting(field, template, inputs, inputs_dict_st):
221+
from .specs import MultiInputObj, MultiOutputFile
222+
223+
keep_extension = field.metadata.get("keep_extension", True)
217224
inp_fields = re.findall(r"{\w+}", template)
218225
inp_fields_fl = re.findall(r"{\w+:[0-9.]+f}", template)
219226
inp_fields += [re.sub(":[0-9.]+f", "", el) for el in inp_fields_fl]

pydra/engine/specs.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ class MultiOutputType:
4646
MultiOutputObj = ty.Union[list, object, MultiOutputType]
4747
MultiOutputFile = ty.Union[File, ty.List[File], MultiOutputType]
4848

49+
OUTPUT_TEMPLATE_TYPES = (
50+
Path,
51+
ty.List[Path],
52+
ty.Union[Path, bool],
53+
ty.Union[ty.List[Path], bool],
54+
ty.List[ty.List[Path]],
55+
)
56+
4957

5058
@attr.s(auto_attribs=True, kw_only=True)
5159
class SpecInfo:
@@ -343,6 +351,8 @@ def check_metadata(self):
343351
Also sets the default values when available and needed.
344352
345353
"""
354+
from ..utils.typing import TypeParser
355+
346356
supported_keys = {
347357
"allowed_values",
348358
"argstr",
@@ -361,6 +371,7 @@ def check_metadata(self):
361371
"formatter",
362372
"_output_type",
363373
}
374+
364375
for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")):
365376
mdata = fld.metadata
366377
# checking keys from metadata
@@ -377,16 +388,13 @@ def check_metadata(self):
377388
)
378389
# assuming that fields with output_file_template shouldn't have default
379390
if mdata.get("output_file_template"):
380-
if fld.type not in (
381-
Path,
382-
ty.Union[Path, bool],
383-
str,
384-
ty.Union[str, bool],
391+
if not any(
392+
TypeParser.matches_type(fld.type, t) for t in OUTPUT_TEMPLATE_TYPES
385393
):
386394
raise TypeError(
387-
f"Type of '{fld.name}' should be either pathlib.Path or "
388-
f"typing.Union[pathlib.Path, bool] (not {fld.type}) because "
389-
f"it has a value for output_file_template ({mdata['output_file_template']!r})"
395+
f"Type of '{fld.name}' should be one of {OUTPUT_TEMPLATE_TYPES} "
396+
f"(not {fld.type}) because it has a value for output_file_template "
397+
f"({mdata['output_file_template']!r})"
390398
)
391399
if fld.default not in [attr.NOTHING, True, False]:
392400
raise AttributeError(
@@ -443,7 +451,8 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
443451
input_value = getattr(inputs, fld.name, attr.NOTHING)
444452
if input_value is not attr.NOTHING:
445453
if TypeParser.contains_type(FileSet, fld.type):
446-
input_value = TypeParser(fld.type).coerce(input_value)
454+
label = f"output field '{fld.name}' of {self}"
455+
input_value = TypeParser(fld.type, label=label).coerce(input_value)
447456
additional_out[fld.name] = input_value
448457
elif (
449458
fld.default is None or fld.default == attr.NOTHING

pydra/engine/tests/test_helpers_file.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import typing as ty
22
import sys
33
from pathlib import Path
4+
import attr
5+
from unittest.mock import Mock
46
import pytest
57
from fileformats.generic import File
8+
from ..specs import SpecInfo, ShellSpec
9+
from ..task import ShellCommandTask
610
from ..helpers_file import (
711
ensure_list,
812
MountIndentifier,
913
copy_nested_files,
14+
template_update_single,
1015
)
1116

1217

@@ -343,3 +348,72 @@ def test_cifs_check():
343348
with MountIndentifier.patch_table(fake_table):
344349
for target, expected in cifs_targets:
345350
assert MountIndentifier.on_cifs(target) is expected
351+
352+
353+
def test_output_template(tmp_path):
354+
filename = str(tmp_path / "file.txt")
355+
with open(filename, "w") as f:
356+
f.write("hello from pydra")
357+
in_file = File(filename)
358+
359+
my_input_spec = SpecInfo(
360+
name="Input",
361+
fields=[
362+
(
363+
"in_file",
364+
attr.ib(
365+
type=File,
366+
metadata={
367+
"mandatory": True,
368+
"position": 1,
369+
"argstr": "",
370+
"help_string": "input file",
371+
},
372+
),
373+
),
374+
(
375+
"optional",
376+
attr.ib(
377+
type=ty.Union[Path, bool],
378+
default=False,
379+
metadata={
380+
"position": 2,
381+
"argstr": "--opt",
382+
"output_file_template": "{in_file}.out",
383+
"help_string": "optional file output",
384+
},
385+
),
386+
),
387+
],
388+
bases=(ShellSpec,),
389+
)
390+
391+
class MyCommand(ShellCommandTask):
392+
executable = "my"
393+
input_spec = my_input_spec
394+
395+
task = MyCommand(in_file=filename)
396+
assert task.cmdline == f"my {filename}"
397+
task.inputs.optional = True
398+
assert task.cmdline == f"my {filename} --opt {task.output_dir / 'file.out'}"
399+
task.inputs.optional = False
400+
assert task.cmdline == f"my {filename}"
401+
task.inputs.optional = "custom-file-out.txt"
402+
assert task.cmdline == f"my {filename} --opt custom-file-out.txt"
403+
404+
405+
def test_template_formatting(tmp_path):
406+
field = Mock()
407+
field.name = "grad"
408+
field.argstr = "--grad"
409+
field.metadata = {"output_file_template": ("{in_file}.bvec", "{in_file}.bval")}
410+
inputs = Mock()
411+
inputs_dict = {"in_file": "/a/b/c/file.txt", "grad": True}
412+
413+
assert template_update_single(
414+
field,
415+
inputs,
416+
inputs_dict_st=inputs_dict,
417+
output_dir=tmp_path,
418+
spec_type="input",
419+
) == [str(tmp_path / "file.bvec"), str(tmp_path / "file.bval")]

pydra/utils/tests/test_typing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,31 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
597597
assert out_file.parent != in_file.parent
598598
assert type(out_file.header) is MyHeader
599599
assert out_file.header.parent != in_file.header.parent
600+
601+
602+
def test_type_is_subclass1():
603+
assert TypeParser.is_subclass(ty.Type[File], type)
604+
605+
606+
def test_type_is_subclass2():
607+
assert not TypeParser.is_subclass(ty.Type[File], ty.Type[Json])
608+
609+
610+
def test_type_is_subclass3():
611+
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])
612+
613+
614+
def test_type_is_instance1():
615+
assert TypeParser.is_instance(File, ty.Type[File])
616+
617+
618+
def test_type_is_instance2():
619+
assert not TypeParser.is_instance(File, ty.Type[Json])
620+
621+
622+
def test_type_is_instance3():
623+
assert TypeParser.is_instance(Json, ty.Type[File])
624+
625+
626+
def test_type_is_instance4():
627+
assert TypeParser.is_instance(Json, type)

0 commit comments

Comments
 (0)