Skip to content

Commit 5c101a8

Browse files
committed
debugging test_workflow
1 parent 4a51195 commit 5c101a8

File tree

3 files changed

+104
-118
lines changed

3 files changed

+104
-118
lines changed

pydra/engine/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def construct(
636636

637637
# Initialise the outputs of the workflow
638638
outputs = definition.Outputs(
639-
**{f.name: attrs.NOTHING for f in attrs.fields(definition.Outputs)}
639+
**{f.name: attrs.NOTHING for f in list_fields(definition.Outputs)}
640640
)
641641

642642
# Initialise the lzin fields

pydra/engine/specs.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,49 @@ class TaskOutputs:
6767

6868
RESERVED_FIELD_NAMES = ("inputs",)
6969

70+
_output_dir: Path = attrs.field(default=None, init=False, repr=False)
71+
7072
@property
7173
def inputs(self):
7274
"""The inputs object associated with a lazy-outputs object"""
7375
return self._get_node().inputs
7476

7577
@classmethod
76-
def _from_defaults(cls) -> Self:
77-
"""Create an output object from the default values of the fields"""
78-
return cls(
78+
def _from_task(cls, task: "Task[DefType]") -> Self:
79+
"""Collect the outputs of a task. This is just an abstract base method that
80+
should be used by derived classes to set default values for the outputs.
81+
82+
Parameters
83+
----------
84+
task : Task[DefType]
85+
The task whose outputs are being collected.
86+
87+
Returns
88+
-------
89+
outputs : Outputs
90+
The outputs of the task
91+
"""
92+
outputs = cls(
7993
**{
8094
f.name: (
8195
f.default.factory()
8296
if isinstance(f.default, attrs.Factory)
8397
else f.default
8498
)
8599
for f in attrs_fields(cls)
100+
if not f.name.startswith("_")
86101
}
87102
)
103+
outputs._output_dir = task.output_dir
104+
return outputs
105+
106+
@property
107+
def _results(self) -> "Result[Self]":
108+
results_path = self._output_dir / "_task.pklz"
109+
if not results_path.exists():
110+
raise FileNotFoundError(f"Task results file {results_path} not found")
111+
with open(results_path, "rb") as f:
112+
return cp.load(f)
88113

89114
def _get_node(self):
90115
try:
@@ -718,7 +743,7 @@ def _from_task(cls, task: "Task[PythonDef]") -> Self:
718743
outputs : Outputs
719744
The outputs of the task in dataclass
720745
"""
721-
outputs = cls._from_defaults()
746+
outputs = super()._from_task(task)
722747
for name, val in task.return_values.items():
723748
setattr(outputs, name, val)
724749
return outputs
@@ -739,7 +764,7 @@ def _run(self, task: "Task[PythonDef]", rerun: bool = True) -> None:
739764
# Run the actual function
740765
returned = self.function(**inputs)
741766
# Collect the outputs and save them into the task.return_values dictionary
742-
task.return_values = {f.name: f.default for f in attrs.fields(self.Outputs)}
767+
task.return_values = {f.name: f.default for f in list_fields(self.Outputs)}
743768
return_names = list(task.return_values)
744769
if returned is None:
745770
task.return_values = {nm: None for nm in return_names}
@@ -773,7 +798,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
773798
outputs : Outputs
774799
The outputs of the task
775800
"""
776-
outputs = cls._from_defaults()
801+
outputs = super()._from_task(task)
777802
# collecting outputs from tasks
778803
output_wf = {}
779804
lazy_field: lazy.LazyOutField
@@ -806,7 +831,9 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
806831
else "\n" + "\n".join(str(f) for f in err_files)
807832
)
808833
)
809-
return attrs.evolve(outputs, **output_wf)
834+
outputs = attrs.evolve(outputs, **output_wf)
835+
outputs._output_dir = task.output_dir
836+
return outputs
810837

811838

812839
WorkflowOutputsType = ty.TypeVar("OutputType", bound=WorkflowOutputs)
@@ -876,7 +903,7 @@ def _from_task(cls, task: "Task[ShellDef]") -> Self:
876903
outputs : ShellOutputs
877904
The outputs of the shell process
878905
"""
879-
outputs = cls._from_defaults()
906+
outputs = super()._from_task(task)
880907
fld: shell.out
881908
for fld in list_fields(cls):
882909
if fld.name in ["return_code", "stdout", "stderr"]:
@@ -905,21 +932,6 @@ def _from_task(cls, task: "Task[ShellDef]") -> Self:
905932
)
906933
return outputs
907934

908-
# @classmethod
909-
# def _from_defaults(cls) -> Self:
910-
# """Create an output object from the default values of the fields"""
911-
# defaults = {}
912-
# for field in attrs_fields(cls):
913-
# if isinstance(field.default, attrs.Factory):
914-
# defaults[field.name] = field.default.factory()
915-
# elif TypeParser.contains_type(FileSet, field.type):
916-
# # Will be set by the templating code
917-
# defaults[field.name] = attrs.NOTHING
918-
# else:
919-
# defaults[field.name] = field.default
920-
921-
# return cls(**defaults)
922-
923935
@classmethod
924936
def _resolve_default_value(cls, fld: shell.out, output_dir: Path) -> ty.Any:
925937
"""Resolve path and glob expr default values relative to the output dir"""

0 commit comments

Comments
 (0)