Skip to content

Commit 7245524

Browse files
committed
require all jobs in predecessor nodes to have completed successfully before successor nodes are run (will look to relax it with a partially generated state)
1 parent 3e5090e commit 7245524

File tree

7 files changed

+271

-192

lines changed

pydra/engine/graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def node(self, name: str) -> NodeType:
9999
except KeyError:
100100
raise KeyError(f"Node {name!r} not found in graph") from None
101101

102+
def __getitem__(self, key):
103+
"""Get a node by its name."""
104+
return self.node(key)
105+
102106
@property
103107
def nodes_names_map(self) -> dict[str, NodeType]:
104108
"""Get a map of node names to nodes."""

pydra/engine/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def save(
191191
task_path: Path,
192192
result: "Result | None" = None,
193193
task: "Task[DefType] | None" = None,
194+
return_values: dict[str, ty.Any] | None = None,
194195
name_prefix: str = None,
195196
) -> None:
196197
"""
@@ -204,6 +205,8 @@ def save(
204205
Result to pickle and write
205206
task : :class:`~pydra.engine.core.TaskBase`
206207
Task to pickle and write
208+
return_values : :obj:`dict`
209+
Return values to pickle and write
207210
"""
208211
from pydra.engine.core import is_workflow
209212

@@ -233,6 +236,9 @@ def save(
233236
if task:
234237
with (task_path / f"{name_prefix}_task.pklz").open("wb") as fp:
235238
cp.dump(task, fp)
239+
if return_values:
240+
with (task_path / f"{name_prefix}_return_values.pklz").open("wb") as fp:
241+
cp.dump(return_values, fp)
236242

237243

238244
def copyfile_workflow(

pydra/engine/lazy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _get_value(
180180
def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
181181
if job.errored:
182182
raise ValueError(
183-
f"Cannot retrieve value for {self._field} from {self._node.name} as "
183+
f"Cannot retrieve value for {self._field!r} from {self._node.name} as "
184184
"the node errored"
185185
)
186186
res = job.result()

pydra/engine/specs.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,10 @@ def __call__(
293293
)
294294
raise
295295
if result.errored:
296-
if isinstance(self, WorkflowDef) or self._splitter:
297-
raise RuntimeError(f"Workflow {self} failed with errors")
298-
else:
299-
errors = result.errors
300-
raise RuntimeError(
301-
f"Task {self} failed @ {errors['time of crash']} with the following errors:\n"
302-
+ "\n".join(errors["error message"])
303-
)
296+
raise RuntimeError(
297+
f"Task {self} failed @ {result.errors['time of crash']} with the "
298+
"following errors:\n" + "\n".join(result.errors["error message"])
299+
)
304300
return result.outputs
305301

306302
def split(
@@ -697,6 +693,18 @@ def task(self):
697693
with open(task_pkl, "rb") as f:
698694
return cp.load(f)
699695

696+
@property
697+
def return_values(self):
698+
return_values_pkl = self.output_dir / "_return_values.pklz"
699+
if not return_values_pkl.exists():
700+
return None
701+
with open(return_values_pkl, "rb") as f:
702+
return cp.load(f)
703+
704+
@property
705+
def job(self):
706+
return self.task
707+
700708

701709
@attrs.define(kw_only=True)
702710
class RuntimeSpec:
@@ -798,40 +806,37 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
798806
outputs : Outputs
799807
The outputs of the task
800808
"""
801-
outputs = super()._from_task(task)
802-
# collecting outputs from tasks
803-
output_wf = {}
804-
lazy_field: lazy.LazyOutField
809+
805810
workflow: "Workflow" = task.return_values["workflow"]
806811
exec_graph: "DiGraph[NodeExecution]" = task.return_values["exec_graph"]
807-
nodes_dict = {n.name: n for n in exec_graph.nodes}
808-
for name, lazy_field in attrs_values(workflow.outputs).items():
809-
try:
810-
val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
811-
if isinstance(val_out, StateArray):
812-
val_out = list(val_out) # implicitly combine state arrays
813-
output_wf[name] = val_out
814-
except (ValueError, AttributeError):
815-
output_wf[name] = None
816-
node: "NodeExecution" = nodes_dict[lazy_field._node.name]
817-
# checking if the tasks has predecessors that raises error
818-
if isinstance(node.errored, list):
819-
raise ValueError(f"Tasks {node._errored} raised an error")
820-
else:
821-
err_files = [(t.output_dir / "_error.pklz") for t in node.tasks]
822-
err_files = [f for f in err_files if f.exists()]
823-
if not err_files:
824-
raise
825-
raise ValueError(
826-
f"Task {lazy_field._node.name!r} raised an error, full crash report is "
827-
f"here: "
828-
+ (
829-
str(err_files[0])
830-
if len(err_files) == 1
831-
else "\n" + "\n".join(str(f) for f in err_files)
832-
)
812+
813+
# Check for errors in any of the workflow nodes
814+
if errored := [n for n in exec_graph.nodes if n.errored]:
815+
errors = []
816+
for node in errored:
817+
for task in node.errored.values():
818+
result = task.result()
819+
errors.append(
820+
f"Task {node.name!r} failed @ {result.errors['time of crash']} "
821+
"with the following errors:\n"
822+
+ "\n".join(result.errors["error message"])
833823
)
834-
outputs = attrs.evolve(outputs, **output_wf)
824+
raise RuntimeError(
825+
f"Workflow {workflow} failed with errors: " + "\n\n".join(errors)
826+
)
827+
828+
# Retrieve values from the output fields
829+
values = {}
830+
lazy_field: lazy.LazyOutField
831+
for name, lazy_field in attrs_values(workflow.outputs).items():
832+
val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
833+
if isinstance(val_out, StateArray):
834+
val_out = list(val_out) # implicitly combine state arrays
835+
values[name] = val_out
836+
837+
# Set the values in the outputs object
838+
outputs = super()._from_task(task)
839+
outputs = attrs.evolve(outputs, **values)
835840
outputs._output_dir = task.output_dir
836841
return outputs
837842

pydra/engine/submitter.py

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class Submitter:
8383
messenger_args: dict[str, ty.Any]
8484
clean_stale_locks: bool
8585
run_start_time: datetime | None
86+
propagate_rerun: bool
8687

8788
def __init__(
8889
self,
@@ -94,6 +95,7 @@ def __init__(
9495
audit_flags: AuditFlag = AuditFlag.NONE,
9596
messengers: ty.Iterable[Messenger] | None = None,
9697
messenger_args: dict[str, ty.Any] | None = None,
98+
propagate_rerun: bool = True,
9799
clean_stale_locks: bool | None = None,
98100
**kwargs,
99101
):
@@ -121,6 +123,7 @@ def __init__(
121123

122124
self.cache_dir = cache_dir
123125
self.cache_locations = cache_locations
126+
self.propagate_rerun = propagate_rerun
124127
self.environment = environment if environment is not None else Native()
125128
self.loop = get_open_loop()
126129
self._own_loop = not self.loop.is_running()
@@ -188,6 +191,8 @@ def __call__(
188191
rerun : bool, optional
189192
Whether to force the re-computation of the task results even if existing
190193
results are found, by default False
194+
propagate_rerun : bool, optional
195+
Whether to propagate the rerun flag to all tasks in the workflow, by default True
191196
192197
Returns
193198
-------
@@ -312,12 +317,12 @@ def expand_workflow(self, workflow_task: "Task[WorkflowDef]", rerun: bool) -> No
312317
wf = workflow_task.definition.construct()
313318
# Generate the execution graph
314319
exec_graph = wf.execution_graph(submitter=self)
320+
workflow_task.return_values = {"workflow": wf, "exec_graph": exec_graph}
315321
tasks = self.get_runnable_tasks(exec_graph)
316322
while tasks or any(not n.done for n in exec_graph.nodes):
317323
for task in tasks:
318324
self.worker.run(task, rerun=rerun)
319325
tasks = self.get_runnable_tasks(exec_graph)
320-
workflow_task.return_values = {"workflow": wf, "exec_graph": exec_graph}
321326

322327
async def expand_workflow_async(
323328
self, workflow_task: "Task[WorkflowDef]", rerun: bool
@@ -333,6 +338,7 @@ async def expand_workflow_async(
333338
wf = workflow_task.definition.construct()
334339
# Generate the execution graph
335340
exec_graph = wf.execution_graph(submitter=self)
341+
workflow_task.return_values = {"workflow": wf, "exec_graph": exec_graph}
336342
# keep track of pending futures
337343
task_futures = set()
338344
tasks = self.get_runnable_tasks(exec_graph)
@@ -417,7 +423,6 @@ async def expand_workflow_async(
417423
task_futures.add(self.worker.run(task, rerun=rerun))
418424
task_futures = await self.worker.fetch_finished(task_futures)
419425
tasks = self.get_runnable_tasks(exec_graph)
420-
workflow_task.return_values = {"workflow": wf, "exec_graph": exec_graph}
421426

422427
def __enter__(self):
423428
return self
@@ -467,7 +472,8 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
467472
continue
468473
# since the list is sorted (breadth-first) we can stop
469474
# when we find a task that depends on any task that is already in tasks
470-
if set(graph.predecessors[node.name]).intersection(not_started):
475+
preds = set(graph.predecessors[node.name])
476+
if preds.intersection(not_started):
471477
break
472478
# Record if the node has not been started
473479
if not node.started:
@@ -619,6 +625,11 @@ def done(self) -> bool:
619625
# Check to see if any previously queued tasks have completed
620626
return not (self.queued or self.blocked or self.running)
621627

628+
@property
629+
def has_errored(self) -> bool:
630+
self.update_status()
631+
return bool(self.errored)
632+
622633
def update_status(self) -> None:
623634
"""Updates the status of the tasks in the node."""
624635
if not self.started:
@@ -729,56 +740,80 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
729740
List of tasks that are ready to run
730741
"""
731742
runnable: list["Task[DefType]"] = []
732-
if not self.started:
733-
self.start()
734-
# Check to see if any blocked tasks are now runnable/unrunnable
735-
for index, task in list(self.blocked.items()):
736-
pred: NodeExecution
737-
is_runnable = True
738-
# This is required for the commented-out code below
739-
# states_ind = (
740-
# list(self.node.state.states_ind[index].items())
741-
# if self.node.state
742-
# else []
743-
# )
744-
for pred in graph.predecessors[self.node.name]:
745-
if pred.node.state:
746-
# FIXME: These should be the only predecessor jobs that are required to have
747-
# completed before the job can be run, however, due to how the state
748-
# is currently built, all predecessors are required to have completed.
749-
# If/when this is relaxed, then the following code should be used instead.
750-
#
751-
# pred_states_ind = {
752-
# (k, i) for k, i in states_ind if k.startswith(pred.name + ".")
753-
# }
754-
# pred_inds = [
755-
# i
756-
# for i, ind in enumerate(pred.node.state.states_ind)
757-
# if set(ind.items()).issuperset(pred_states_ind)
758-
# ]
759-
pred_inds = list(range(len(pred.node.state.states_ind)))
760-
else:
761-
pred_inds = [None]
762-
if not all(i in pred.successful for i in pred_inds):
763-
is_runnable = False
764-
blocked = True
765-
if pred_errored := [i for i in pred_inds if i in pred.errored]:
766-
self.unrunnable[index].extend(
767-
[pred.errored[i] for i in pred_errored]
768-
)
769-
blocked = False
770-
if pred_unrunnable := [
771-
i for i in pred_inds if i in pred.unrunnable
772-
]:
773-
self.unrunnable[index].extend(
774-
[pred.unrunnable[i] for i in pred_unrunnable]
775-
)
776-
blocked = False
777-
if not blocked:
778-
del self.blocked[index]
779-
break
780-
if is_runnable:
781-
runnable.append(self.blocked.pop(index))
743+
predecessors: list["Task[DefType]"] = graph.predecessors[self.node.name]
744+
745+
# If there is a split, we need to wait for all predecessor nodes to finish
746+
# In theory, if the current splitter splits an already split state we should
747+
# only need to wait for the direct predecessor jobs to finish, however, this
748+
# would require a deep refactor of the State class as we need the whole state
749+
# in order to assign consistent state indices across the new split
750+
751+
# FIXME: The branch for handling partially completed/errored/unrunnable
752+
# predecessor nodes can't be used until the State class can be partially
753+
# initialised with lazy-fields.
754+
if True: # self.node.splitter:
755+
if unrunnable := [p for p in predecessors if p.errored or p.unrunnable]:
756+
self.unrunnable = {None: unrunnable}
757+
self.blocked = {}
758+
assert self.done
759+
else:
760+
if all(p.done for p in predecessors):
761+
if not self.started:
762+
self.start()
763+
if self.node.state is None:
764+
inds = [None]
765+
else:
766+
inds = list(range(len(self.node.state.states_ind)))
767+
if self.blocked:
768+
for i in inds:
769+
runnable.append(self.blocked.pop(i))
770+
else:
771+
if not self.started:
772+
self.start()
773+
774+
# Check to see if any blocked tasks are now runnable/unrunnable
775+
for index, task in list(self.blocked.items()):
776+
pred: NodeExecution
777+
is_runnable = True
778+
states_ind = (
779+
list(self.node.state.states_ind[index].items())
780+
if self.node.state
781+
else []
782+
)
783+
for pred in predecessors:
784+
if pred.node.state:
785+
pred_states_ind = {
786+
(k, i)
787+
for k, i in states_ind
788+
if k.startswith(pred.name + ".")
789+
}
790+
pred_inds = [
791+
i
792+
for i, ind in enumerate(pred.node.state.states_ind)
793+
if set(ind.items()).issuperset(pred_states_ind)
794+
]
795+
else:
796+
pred_inds = [None]
797+
if not all(i in pred.successful for i in pred_inds):
798+
is_runnable = False
799+
blocked = True
800+
if pred_errored := [
801+
pred.errored[i] for i in pred_inds if i in pred.errored
802+
]:
803+
self.unrunnable[index].extend(pred_errored)
804+
blocked = False
805+
if pred_unrunnable := [
806+
pred.unrunnable[i]
807+
for i in pred_inds
808+
if i in pred.unrunnable
809+
]:
810+
self.unrunnable[index].extend(pred_unrunnable)
811+
blocked = False
812+
if not blocked:
813+
del self.blocked[index]
814+
break
815+
if is_runnable:
816+
runnable.append(self.blocked.pop(index))
782817
self.queued.update({t.state_index: t for t in runnable})
783818
return list(self.queued.values())
784819

0 commit comments

Comments
 (0)