Skip to content

Commit f4b08bf

Browse files
committed
fixed up the type-checking of fields with output_file_template
1 parent 62e3944 commit f4b08bf

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

pydra/engine/helpers_file.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -151,25 +151,21 @@ def template_update_single(
151151
# if input_dict_st with state specific value is not available,
152152
# the dictionary will be created from inputs object
153153
from ..utils.typing import TypeParser # noqa
154-
from pydra.engine.specs import LazyField
155-
156-
VALID_TYPES = (str, ty.Union[str, bool], Path, ty.Union[Path, bool], LazyField)
154+
from pydra.engine.specs import LazyField, OUTPUT_TEMPLATE_TYPES
157155

158156
if inputs_dict_st is None:
159157
inputs_dict_st = attr.asdict(inputs, recurse=False)
160158

161159
if spec_type == "input":
162160
inp_val_set = inputs_dict_st[field.name]
163-
if inp_val_set is not attr.NOTHING and not TypeParser.is_instance(
164-
inp_val_set, VALID_TYPES
165-
):
166-
raise TypeError(
167-
f"'{field.name}' field has to be a Path instance or a bool, but {inp_val_set} set"
168-
)
169161
if isinstance(inp_val_set, bool) and field.type in (Path, str):
170162
raise TypeError(
171163
f"type of '{field.name}' is Path, consider using Union[Path, bool]"
172164
)
165+
if inp_val_set is not attr.NOTHING and not isinstance(LazyField):
166+
inp_val_set = TypeParser(ty.Union.__getitem__(OUTPUT_TEMPLATE_TYPES))(
167+
inp_val_set
168+
)
173169
elif spec_type == "output":
174170
if not TypeParser.contains_type(FileSet, field.type):
175171
raise TypeError(
@@ -179,22 +175,23 @@ def template_update_single(
179175
else:
180176
raise TypeError(f"spec_type can be input or output, but {spec_type} provided")
181177
# for inputs that the value is set (so the template is ignored)
182-
if spec_type == "input" and isinstance(inputs_dict_st[field.name], (str, Path)):
183-
return inputs_dict_st[field.name]
184-
elif spec_type == "input" and inputs_dict_st[field.name] is False:
185-
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
186-
return attr.NOTHING
187-
else: # inputs_dict[field.name] is True or spec_type is output
188-
value = _template_formatting(field, inputs, inputs_dict_st)
189-
# changing path so it is in the output_dir
190-
if output_dir and value is not attr.NOTHING:
191-
# should be converted to str, it is also used for input fields that should be str
192-
if type(value) is list:
193-
return [str(output_dir / Path(val).name) for val in value]
194-
else:
195-
return str(output_dir / Path(value).name)
196-
else:
178+
if spec_type == "input":
179+
if isinstance(inp_val_set, (Path, list)):
180+
return inp_val_set
181+
if inp_val_set is False:
182+
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
197183
return attr.NOTHING
184+
# inputs_dict[field.name] is True or spec_type is output
185+
value = _template_formatting(field, inputs, inputs_dict_st)
186+
# changing path so it is in the output_dir
187+
if output_dir and value is not attr.NOTHING:
188+
# should be converted to str, it is also used for input fields that should be str
189+
if type(value) is list:
190+
return [str(output_dir / Path(val).name) for val in value]
191+
else:
192+
return str(output_dir / Path(value).name)
193+
else:
194+
return attr.NOTHING
198195

199196

200197
def _template_formatting(field, inputs, inputs_dict_st):

pydra/engine/specs.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ class MultiOutputType:
4646
MultiOutputObj = ty.Union[list, object, MultiOutputType]
4747
MultiOutputFile = ty.Union[File, ty.List[File], MultiOutputType]
4848

49+
OUTPUT_TEMPLATE_TYPES = (
50+
Path,
51+
ty.List[Path],
52+
ty.Union[Path, bool],
53+
ty.Union[ty.List[Path], bool],
54+
)
55+
4956

5057
@attr.s(auto_attribs=True, kw_only=True)
5158
class SpecInfo:
@@ -343,6 +350,8 @@ def check_metadata(self):
343350
Also sets the default values when available and needed.
344351
345352
"""
353+
from ..utils.typing import TypeParser
354+
346355
supported_keys = {
347356
"allowed_values",
348357
"argstr",
@@ -361,6 +370,7 @@ def check_metadata(self):
361370
"formatter",
362371
"_output_type",
363372
}
373+
364374
for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")):
365375
mdata = fld.metadata
366376
# checking keys from metadata
@@ -377,16 +387,13 @@ def check_metadata(self):
377387
)
378388
# assuming that fields with output_file_template shouldn't have default
379389
if mdata.get("output_file_template"):
380-
if fld.type not in (
381-
Path,
382-
ty.Union[Path, bool],
383-
str,
384-
ty.Union[str, bool],
390+
if not any(
391+
TypeParser.matches_type(fld.type, t) for t in OUTPUT_TEMPLATE_TYPES
385392
):
386393
raise TypeError(
387-
f"Type of '{fld.name}' should be either pathlib.Path or "
388-
f"typing.Union[pathlib.Path, bool] (not {fld.type}) because "
389-
f"it has a value for output_file_template ({mdata['output_file_template']!r})"
394+
f"Type of '{fld.name}' should be one of {OUTPUT_TEMPLATE_TYPES} "
395+
f"(not {fld.type}) because it has a value for output_file_template "
396+
f"({mdata['output_file_template']!r})"
390397
)
391398
if fld.default not in [attr.NOTHING, True, False]:
392399
raise AttributeError(

0 commit comments

Comments
 (0)