Skip to content

Commit b78521c

Browse files
committed
fixed serialization of task definitions
1 parent 98718c8 commit b78521c

File tree

6 files changed

+27
-8
lines changed

6 files changed

+27
-8
lines changed

pydra/compose/base/helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,13 @@ def ensure_field_objects(
114114
out_kwds = copy(out)
115115
if "help" not in out_kwds:
116116
out_kwds["help"] = output_helps.get(output_name, "")
117-
outputs[output_name] = out_type(
117+
if "path_template" in out_kwds:
118+
from pydra.compose.shell import outarg
119+
120+
out_type_ = outarg
121+
else:
122+
out_type_ = out_type
123+
outputs[output_name] = out_type_(
118124
name=output_name,
119125
**out_kwds,
120126
)

pydra/compose/base/task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Outputs:
3737
"""Base class for all output definitions"""
3838

3939
RESERVED_FIELD_NAMES = ("inputs",)
40+
BASE_ATTRS = ()
4041

4142
_cache_dir: Path = attrs.field(default=None, init=False, repr=False)
4243
_node = attrs.field(default=None, init=False, repr=False)
@@ -178,6 +179,7 @@ def _task_type(cls) -> str:
178179
_hashes = attrs.field(default=None, init=False, eq=False, repr=False)
179180

180181
RESERVED_FIELD_NAMES = ("split", "combine")
182+
BASE_ATTRS = ()
181183

182184
def __call__(
183185
self,

pydra/compose/shell/builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ def make(
183183
class_name = f"_{class_name}"
184184

185185
# Add in fields from base classes
186-
parsed_inputs.update({n: getattr(Task, n) for n in Task.BASE_NAMES})
187-
parsed_outputs.update({n: getattr(Outputs, n) for n in Outputs.BASE_NAMES})
186+
parsed_inputs.update({n: getattr(Task, n) for n in Task.BASE_ATTRS})
187+
parsed_outputs.update({n: getattr(Outputs, n) for n in Outputs.BASE_ATTRS})
188188

189189
if "executable" in parsed_inputs:
190190
raise ValueError(

pydra/compose/shell/field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _callable_validator(self, _, value):
134134
"stdout",
135135
"stderr",
136136
]
137-
): # shell.Outputs.BASE_NAMES
137+
): # shell.Outputs.BASE_ATTRS
138138
raise ValueError(
139139
"A shell output field must have either a callable or a path_template"
140140
)

pydra/compose/shell/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
class ShellOutputs(base.Outputs):
4747
"""Output task of a generic shell process."""
4848

49-
BASE_NAMES = ["return_code", "stdout", "stderr"]
49+
BASE_ATTRS = ["return_code", "stdout", "stderr"]
5050
RETURN_CODE_HELP = """The process' exit code."""
5151
STDOUT_HELP = """The standard output stream produced by the command."""
5252
STDERR_HELP = """The standard error stream produced by the command."""
@@ -234,7 +234,7 @@ class ShellTask(base.Task[ShellOutputsType]):
234234

235235
_executor_name = "executable"
236236

237-
BASE_NAMES = ["append_args"]
237+
BASE_ATTRS = ("append_args",)
238238

239239
EXECUTABLE_HELP = (
240240
"the first part of the command, can be a string, "

pydra/utils/general.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,22 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
602602
dict[str, ty.Any]
603603
A dictionary representation of the Pydra task.
604604
"""
605+
from pydra.compose.base import Out
606+
605607
input_fields = task_fields(task_def)
606608
executor = input_fields.pop(task_def._executor_name).default
607-
input_dicts = [attrs.asdict(i, filter=_filter_defaults) for i in input_fields]
609+
input_dicts = [
610+
attrs.asdict(i, filter=_filter_defaults)
611+
for i in input_fields
612+
if (
613+
not isinstance(i, Out) # filter out outarg fields
614+
and i.name not in task_def.BASE_ATTRS
615+
)
616+
]
608617
output_dicts = [
609-
attrs.asdict(o, filter=_filter_defaults) for o in task_fields(task_def.Outputs)
618+
attrs.asdict(o, filter=_filter_defaults)
619+
for o in task_fields(task_def.Outputs)
620+
if o.name not in task_def.Outputs.BASE_ATTRS
610621
]
611622
dct = {
612623
"type": task_def._task_type(),

0 commit comments

Comments
 (0)