Skip to content

Commit 92c0a7d

Browse files
authored
Merge pull request #389 from chasejohnson3/master
Added allowed types for an output_spec field
2 parents 044ccba + 5147e05 commit 92c0a7d

File tree

3 files changed

+127
-10
lines changed

3 files changed

+127
-10
lines changed

pydra/engine/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,9 @@ def _collect_outputs(self, output_dir):
486486
run_output = self.output_
487487
output_klass = make_klass(self.output_spec)
488488
output = output_klass(**{f.name: None for f in attr.fields(output_klass)})
489-
other_output = output.collect_additional_outputs(self.inputs, output_dir)
489+
other_output = output.collect_additional_outputs(
490+
self.inputs, output_dir, run_output
491+
)
490492
return attr.evolve(output, **run_output, **other_output)
491493

492494
def split(self, splitter, overwrite=False, **kwargs):

pydra/engine/specs.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __setattr__(self, name, value):
9595
# validate all fields that have set a validator
9696
attr.validate(self)
9797

98-
def collect_additional_outputs(self, inputs, output_dir):
98+
def collect_additional_outputs(self, inputs, output_dir, outputs):
9999
"""Get additional outputs."""
100100
return {}
101101

@@ -439,12 +439,21 @@ class ShellOutSpec:
439439
stderr: ty.Union[File, str]
440440
"""The process' standard input."""
441441

442-
def collect_additional_outputs(self, inputs, output_dir):
442+
def collect_additional_outputs(self, inputs, output_dir, outputs):
443443
"""Collect additional outputs from shelltask output_spec."""
444444
additional_out = {}
445445
for fld in attr_fields(self):
446446
if fld.name not in ["return_code", "stdout", "stderr"]:
447-
if fld.type in [File, MultiOutputFile]:
447+
if fld.type in [
448+
File,
449+
MultiOutputFile,
450+
Directory,
451+
int,
452+
float,
453+
bool,
454+
str,
455+
list,
456+
]:
448457
# assuming that field should have either default or metadata, but not both
449458
if (
450459
fld.default is None or fld.default == attr.NOTHING
@@ -457,9 +466,21 @@ def collect_additional_outputs(self, inputs, output_dir):
457466
fld, output_dir
458467
)
459468
elif fld.metadata:
460-
additional_out[fld.name] = self._field_metadata(
461-
fld, inputs, output_dir
462-
)
469+
if (
470+
fld.type in [int, float, bool, str, list]
471+
and "callable" not in fld.metadata
472+
):
473+
raise AttributeError(
474+
f"{fld.type} has to have a callable in metadata"
475+
)
476+
else:
477+
additional_out[fld.name] = self._field_metadata(
478+
fld, inputs, output_dir, outputs
479+
)
480+
# else:
481+
# additional_out[fld.name] = self._field_metadata(
482+
# fld, inputs, output_dir, outputs
483+
# )
463484
else:
464485
raise Exception("not implemented (collect_additional_output)")
465486
return additional_out
@@ -486,7 +507,7 @@ def generated_output_names(self, inputs, output_dir):
486507
output_names.append(fld.name)
487508
elif (
488509
fld.metadata
489-
and self._field_metadata(fld, inputs, output_dir)
510+
and self._field_metadata(fld, inputs, output_dir, outputs=None)
490511
!= attr.NOTHING
491512
):
492513
output_names.append(fld.name)
@@ -522,7 +543,7 @@ def _field_defaultvalue(self, fld, output_dir):
522543
else:
523544
raise AttributeError(f"no file matches {default.name}")
524545

525-
def _field_metadata(self, fld, inputs, output_dir):
546+
def _field_metadata(self, fld, inputs, output_dir, outputs=None):
526547
"""Collect output file if metadata specified."""
527548
if self._check_requires(fld, inputs) is False:
528549
return attr.NOTHING
@@ -551,6 +572,10 @@ def _field_metadata(self, fld, inputs, output_dir):
551572
call_args_val[argnm] = output_dir
552573
elif argnm == "inputs":
553574
call_args_val[argnm] = inputs
575+
elif argnm == "stdout":
576+
call_args_val[argnm] = outputs["stdout"]
577+
elif argnm == "stderr":
578+
call_args_val[argnm] = outputs["stderr"]
554579
else:
555580
try:
556581
call_args_val[argnm] = getattr(inputs, argnm)

pydra/engine/tests/test_shelltask.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os, sys
44
import pytest
55
from pathlib import Path
6-
6+
import re
77

88
from ..task import ShellCommandTask
99
from ..submitter import Submitter
@@ -2837,6 +2837,96 @@ def test_shell_cmd_outputspec_6a(tmpdir, plugin, results_function):
28372837
assert res.output.new_files.exists()
28382838

28392839

2840+
@pytest.mark.parametrize("results_function", [result_no_submitter, result_submitter])
2841+
def test_shell_cmd_outputspec_7a(tmpdir, plugin, results_function):
2842+
"""
2843+
customised output_spec, adding int and str to the output,
2844+
requiring two callables with parameters stdout and stderr
2845+
"""
2846+
cmd = "echo"
2847+
args = ["newfile_1.txt", "newfile_2.txt"]
2848+
2849+
def get_file_index(stdout):
2850+
stdout = re.sub(r".*_", "", stdout)
2851+
stdout = re.sub(r".txt", "", stdout)
2852+
print(stdout)
2853+
return int(stdout)
2854+
2855+
def get_stderr(stderr):
2856+
return f"stderr: {stderr}"
2857+
2858+
my_output_spec = SpecInfo(
2859+
name="Output",
2860+
fields=[
2861+
(
2862+
"out1",
2863+
attr.ib(
2864+
type=File,
2865+
metadata={
2866+
"output_file_template": "{args}",
2867+
"help_string": "output file",
2868+
},
2869+
),
2870+
),
2871+
(
2872+
"out_file_index",
2873+
attr.ib(
2874+
type=int,
2875+
metadata={"help_string": "output file", "callable": get_file_index},
2876+
),
2877+
),
2878+
(
2879+
"stderr_field",
2880+
attr.ib(
2881+
type=str,
2882+
metadata={
2883+
"help_string": "The standard error output",
2884+
"callable": get_stderr,
2885+
},
2886+
),
2887+
),
2888+
],
2889+
bases=(ShellOutSpec,),
2890+
)
2891+
2892+
shelly = ShellCommandTask(
2893+
name="shelly", executable=cmd, args=args, output_spec=my_output_spec
2894+
).split("args")
2895+
2896+
results = results_function(shelly, plugin)
2897+
for index, res in enumerate(results):
2898+
assert res.output.out_file_index == index + 1
2899+
assert res.output.stderr_field == f"stderr: {res.output.stderr}"
2900+
2901+
2902+
def test_shell_cmd_outputspec_7b_error():
2903+
"""
2904+
customised output_spec, adding Int to the output,
2905+
requiring a function to collect output
2906+
"""
2907+
cmd = "echo"
2908+
args = ["newfile_1.txt", "newfile_2.txt"]
2909+
2910+
my_output_spec = SpecInfo(
2911+
name="Output",
2912+
fields=[
2913+
(
2914+
"out",
2915+
attr.ib(
2916+
type=int, metadata={"help_string": "output file", "value": "val"}
2917+
),
2918+
)
2919+
],
2920+
bases=(ShellOutSpec,),
2921+
)
2922+
shelly = ShellCommandTask(
2923+
name="shelly", executable=cmd, args=args, output_spec=my_output_spec
2924+
).split("args")
2925+
with pytest.raises(Exception) as e:
2926+
shelly()
2927+
assert "has to have a callable" in str(e.value)
2928+
2929+
28402930
@pytest.mark.parametrize("results_function", [result_no_submitter, result_submitter])
28412931
def test_shell_cmd_state_outputspec_1(plugin, results_function, tmpdir):
28422932
"""

0 commit comments

Comments
 (0)