Skip to content

Commit 1cc4074

Browse files
committed
debugging combining states to preserve nested lists over staggered combines
1 parent 9df1191 commit 1cc4074

File tree

6 files changed

+148
-97
lines changed

6 files changed

+148
-97
lines changed

pydra/engine/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pydra.engine import state
2323
from .lazy import LazyInField, LazyOutField
2424
from pydra.utils.hash import hash_function, Cache
25-
from pydra.utils.typing import TypeParser, StateArray
25+
from pydra.engine.state import State
2626
from .node import Node
2727
from datetime import datetime
2828
from fileformats.core import FileSet
@@ -710,8 +710,7 @@ def construct(
710710
)
711711
for outpt, outpt_lf in zip(output_fields, output_lazy_fields):
712712
# Automatically combine any uncombined state arrays into a single list
713-
if TypeParser.get_origin(outpt_lf._type) is StateArray:
714-
outpt_lf._type = list[TypeParser.strip_splits(outpt_lf._type)[0]]
713+
outpt_lf._type = State.combine_state_arrays(outpt_lf._type)
715714
setattr(outputs, outpt.name, outpt_lf)
716715
else:
717716
if unset_outputs := [

pydra/engine/lazy.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as ty
22
import abc
3+
from operator import attrgetter
34
import attrs
45
from pydra.utils.typing import StateArray
56
from pydra.utils.hash import hash_single
@@ -152,54 +153,60 @@ def _get_value(
152153
value : Any
153154
the resolved value of the lazy-field
154155
"""
155-
from pydra.utils.typing import (
156-
TypeParser,
157-
) # pylint: disable=import-outside-toplevel
158156
from pydra.engine.state import StateIndex
159157

160158
if state_index is None:
161159
state_index = StateIndex()
162160

163-
task = graph.node(self._node.name).get_tasks(state_index)
164-
_, split_depth = TypeParser.strip_splits(self._type)
165-
166-
def get_nested(task: "Task[DefType]", depth: int):
167-
if isinstance(task, StateArray):
168-
val = [get_nested(task=t, depth=depth - 1) for t in task]
169-
if depth:
170-
val = StateArray[self._type](val)
171-
else:
172-
if task.errored:
173-
raise ValueError(
174-
f"Cannot retrieve value for {self._field} from {self._node.name} as "
175-
"the node errored"
176-
)
177-
res = task.result()
178-
if res is None:
179-
raise RuntimeError(
180-
f"Could not find results of '{task.name}' node in a sub-directory "
181-
f"named '{{{task.checksum}}}' in any of the cache locations.\n"
182-
+ "\n".join(str(p) for p in set(task.cache_locations))
183-
+ f"\n\nThis is likely due to hash changes in '{task.name}' node inputs. "
184-
f"Current values and hashes: {task.inputs}, "
185-
f"{task.definition._hash}\n\n"
186-
"Set loglevel to 'debug' in order to track hash changes "
187-
"throughout the execution of the workflow.\n\n "
188-
"These issues may have been caused by `bytes_repr()` methods "
189-
"that don't return stable hash values for specific object "
190-
"types across multiple processes (see bytes_repr() "
191-
'"singledispatch" function in pydra/utils/hash.py).'
192-
"You may need to write specific `bytes_repr()` "
193-
"implementations (see `pydra.utils.hash.register_serializer`) or a "
194-
"`__bytes_repr__()` dunder method to handle one or more types in "
195-
"your interface inputs."
196-
)
197-
val = res.get_output_field(self._field)
198-
val = self._apply_cast(val)
161+
jobs = sorted(
162+
graph.node(self._node.name).matching_jobs(state_index),
163+
key=attrgetter("state_index"),
164+
)
165+
166+
def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
167+
if job.errored:
168+
raise ValueError(
169+
f"Cannot retrieve value for {self._field} from {self._node.name} as "
170+
"the node errored"
171+
)
172+
res = job.result()
173+
if res is None:
174+
raise RuntimeError(
175+
f"Could not find results of '{job.name}' node in a sub-directory "
176+
f"named '{{{job.checksum}}}' in any of the cache locations.\n"
177+
+ "\n".join(str(p) for p in set(job.cache_locations))
178+
+ f"\n\nThis is likely due to hash changes in '{job.name}' node inputs. "
179+
f"Current values and hashes: {job.inputs}, "
180+
f"{job.definition._hash}\n\n"
181+
"Set loglevel to 'debug' in order to track hash changes "
182+
"throughout the execution of the workflow.\n\n "
183+
"These issues may have been caused by `bytes_repr()` methods "
184+
"that don't return stable hash values for specific object "
185+
"types across multiple processes (see bytes_repr() "
186+
'"singledispatch" function in pydra/utils/hash.py).'
187+
"You may need to write specific `bytes_repr()` "
188+
"implementations (see `pydra.utils.hash.register_serializer`) or a "
189+
"`__bytes_repr__()` dunder method to handle one or more types in "
190+
"your interface inputs."
191+
)
192+
val = res.get_output_field(self._field)
193+
val = self._apply_cast(val)
199194
return val
200195

201-
value = get_nested(task, depth=split_depth)
202-
return value
196+
if not self._node.state.depth(after_combine=False):
197+
assert len(jobs) == 1
198+
return retrieve_from_job(jobs[0])
199+
elif not self._node.state.keys_final: # all states are combined over
200+
return [retrieve_from_job(j) for j in jobs]
201+
elif self._node.state.combiner:
202+
values = StateArray()
203+
for ind in self._node.state.states_ind_final:
204+
values.append(
205+
[retrieve_from_job(j) for j in jobs if j.state_index.matches(ind)]
206+
)
207+
return values
208+
else:
209+
return StateArray(retrieve_from_job(j) for j in jobs)
203210

204211
@property
205212
def _source(self):

pydra/engine/node.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from copy import deepcopy
33
from enum import Enum
44
import attrs
5-
from pydra.utils.typing import TypeParser, StateArray
65
from . import lazy
76
from pydra.engine.helpers import (
87
attrs_values,
@@ -128,12 +127,7 @@ def lzout(self) -> OutputType:
128127
# types based on the number of states the node is split over and whether
129128
# it has a combiner
130129
if self._state:
131-
type_, _ = TypeParser.strip_splits(outpt._type)
132-
if self._state.combiner:
133-
type_ = list[type_]
134-
for _ in range(self._state.depth()):
135-
type_ = StateArray[type_]
136-
outpt._type = type_
130+
outpt._type = self._state.nest_output_type(outpt._type)
137131
# Flag the output lazy fields as being not typed checked (i.e. assigned to
138132
# another node's inputs) yet. This is used to prevent the user from changing
139133
# the type of the output after it has been accessed by connecting it to an

pydra/engine/state.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from copy import deepcopy
44
import itertools
55
from collections import OrderedDict
6+
from operator import itemgetter
67
from functools import reduce
78
import typing as ty
89
from . import helpers_state as hlpst
910
from .helpers import ensure_list, attrs_values
11+
from pydra.utils.typing import StateArray, TypeParser
1012

1113
# from .specs import BaseDef
1214
if ty.TYPE_CHECKING:
@@ -47,6 +49,18 @@ def __len__(self) -> int:
4749
def __iter__(self) -> ty.Generator[str, None, None]:
4850
return iter(self.indices)
4951

52+
def __getitem__(self, key: str) -> int:
53+
return self.indices[key]
54+
55+
def __lt__(self, other: "StateIndex") -> bool:
56+
if set(self.indices) != set(other.indices):
57+
raise ValueError(
58+
f"StateIndex {self} does not contain the same indices as {other}"
59+
)
60+
return sorted(self.indices.items(), key=itemgetter(0)) < sorted(
61+
other.indices.items(), key=itemgetter(0)
62+
)
63+
5064
def __repr__(self) -> str:
5165
return (
5266
"StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
@@ -79,6 +93,21 @@ def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
7993
"""
8094
return type(self)({k: v for k, v in self.indices.items() if k in state_names})
8195

96+
def missing(self, state_names: ty.Iterable[str]) -> ty.List[str]:
97+
"""Return the fields that are missing from the StateIndex
98+
99+
Parameters
100+
----------
101+
state_names : list[str]
102+
the fields to check for
103+
104+
Returns
105+
-------
106+
list[str]
107+
the fields that are missing from the StateIndex
108+
"""
109+
return [f for f in state_names if f not in self.indices]
110+
82111
def matches(self, other: "StateIndex") -> bool:
83112
"""Check if the indices that are present in the other StateIndex match
84113
@@ -92,6 +121,8 @@ def matches(self, other: "StateIndex") -> bool:
92121
bool
93122
True if all the indices in the other StateIndex match
94123
"""
124+
if isinstance(other, dict):
125+
other = StateIndex(other)
95126
if not set(self.indices).issuperset(other.indices):
96127
raise ValueError(
97128
f"StateIndex {self} does not contain all the indices in {other}"
@@ -211,10 +242,6 @@ def __str__(self):
211242
@property
212243
def names(self):
213244
"""Return the names of the states."""
214-
# analysing states from connected tasks if inner_inputs
215-
if not hasattr(self, "keys_final"):
216-
self.prepare_states()
217-
self.prepare_inputs()
218245
previous_states_keys = {
219246
f"_{v.name}": v.keys_final for v in self.inner_inputs.values()
220247
}
@@ -265,6 +292,41 @@ def included(s):
265292
remaining_stack = [s for s in stack if included(s)]
266293
return depth + len(remaining_stack)
267294

295+
def nest_output_type(self, type_: type) -> type:
296+
"""Nests a type of an output field in a combination of lists and state-arrays
297+
based on the state's splitter and combiner
298+
299+
Parameters
300+
----------
301+
type_ : type
302+
the type of the output field
303+
304+
Returns
305+
-------
306+
type
307+
the nested type of the output field
308+
"""
309+
310+
state_array_depth = self.depth()
311+
312+
# If there is a combination, it will get flattened into a single list
313+
if self.depth(after_combine=False) > state_array_depth:
314+
type_ = list[type_]
315+
316+
# Nest the uncombined state arrays around the type
317+
for _ in range(state_array_depth):
318+
type_ = StateArray[type_]
319+
return type_
320+
321+
@classmethod
322+
def combine_state_arrays(cls, type_: type) -> type:
323+
"""Collapses (potentially nested) state array(s) into a single list"""
324+
if TypeParser.get_origin(type_) is StateArray:
325+
# Implicitly combine any remaining uncombined states into a single
326+
# list
327+
type_ = list[TypeParser.strip_splits(type_)[0]]
328+
return type_
329+
268330
@property
269331
def splitter(self):
270332
"""Get the splitter of the state."""

pydra/engine/submitter.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
from pathlib import Path
88
from tempfile import mkdtemp
9-
from copy import copy
9+
from copy import copy, deepcopy
1010
from datetime import datetime
1111
from collections import defaultdict
1212
import attrs
@@ -211,16 +211,15 @@ def __call__(
211211
from pydra.engine.specs import TaskDef
212212

213213
state = State(
214-
name="not-important",
214+
name="outer_split",
215215
definition=task_def,
216-
splitter=task_def._splitter,
217-
combiner=task_def._combiner,
216+
splitter=deepcopy(task_def._splitter),
217+
combiner=deepcopy(task_def._combiner),
218218
)
219-
list_depth = 2 if state.depth(after_combine=False) != state.depth() else 1
220219

221220
def wrap_type(tp):
222-
for _ in range(list_depth):
223-
tp = list[tp]
221+
tp = state.nest_output_type(tp)
222+
tp = state.combine_state_arrays(tp)
224223
return tp
225224

226225
output_types = {
@@ -568,22 +567,27 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
568567
self._tasks = {t.state_index: t for t in self._generate_tasks()}
569568
return self._tasks.values()
570569

571-
def get_tasks(
572-
self, index: StateIndex = StateIndex()
573-
) -> "Task | StateArray[Task[DefType]]":
574-
"""Get a task object for a given state index."""
575-
if not self.tasks:
576-
return StateArray([])
577-
task_index = next(iter(self._tasks)) if self._tasks else StateIndex()
578-
if len(task_index) > len(index):
579-
tasks = []
580-
for ind, task in self._tasks.items():
581-
if ind.matches(index):
582-
tasks.append(task)
583-
return StateArray(tasks)
584-
elif len(index) > len(task_index):
585-
index = index.subset(task_index)
586-
return self._tasks[index]
570+
def matching_jobs(self, index: StateIndex = StateIndex()) -> "StateArray[Task]":
571+
"""Get the jobs that match a given state index.
572+
573+
Parameters
574+
----------
575+
index : StateIndex, optional
576+
The state index of the task to get, by default StateIndex()
577+
"""
578+
matching = StateArray()
579+
if self.tasks:
580+
task_index = next(iter(self._tasks)) if self._tasks else StateIndex()
581+
if len(task_index) > len(index):
582+
# Select matching tasks and return them in nested state-array objects
583+
for ind, task in self._tasks.items():
584+
if ind.matches(index):
585+
matching.append(task)
586+
elif len(index) > len(task_index):
587+
matching.append(
588+
self._tasks[index.subset(task_index)]
589+
) # Return a single task
590+
return matching
587591

588592
@property
589593
def started(self) -> bool:
@@ -740,11 +744,7 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
740744
pred: NodeExecution
741745
is_runnable = True
742746
for pred in graph.predecessors[self.node.name]:
743-
pred_jobs = pred.get_tasks(index)
744-
if isinstance(pred_jobs, StateArray):
745-
pred_inds = [j.state_index for j in pred_jobs]
746-
else:
747-
pred_inds = [pred_jobs.state_index]
747+
pred_inds = [j.state_index for j in pred.matching_jobs(index)]
748748
if not all(i in pred.successful for i in pred_inds):
749749
is_runnable = False
750750
blocked = True

pydra/engine/tests/test_node_task.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def test_task_state_comb_1(plugin_dask_opt, tmp_path):
10331033
assert state.splitter_final is None
10341034
assert state.splitter_rpn_final == []
10351035

1036-
with Submitter(worker=plugin_dask_opt, cache_dir=tmp_path) as sub:
1036+
with Submitter(worker="debug", cache_dir=tmp_path) as sub:
10371037
results = sub(nn)
10381038
assert not results.errored, "\n".join(results.errors["error message"])
10391039

@@ -1147,7 +1147,7 @@ def test_task_state_comb_2(
11471147
assert state.splitter_rpn == state_rpn
11481148
assert state.combiner == state_combiner
11491149

1150-
with Submitter(worker=plugin, cache_dir=tmp_path) as sub:
1150+
with Submitter(worker="debug", cache_dir=tmp_path) as sub:
11511151
results = sub(nn)
11521152
assert not results.errored, "\n".join(results.errors["error message"])
11531153

@@ -1161,18 +1161,7 @@ def test_task_state_comb_2(
11611161
# it should give values of inputs that corresponds to the specific element
11621162
# results_verb = nn.result(return_inputs=True)
11631163

1164-
if state.splitter_rpn_final:
1165-
for i, res in enumerate(expected):
1166-
assert results.outputs.out == res
1167-
# results_verb
1168-
# for i, res_l in enumerate(expected_val):
1169-
# for j, res in enumerate(res_l):
1170-
# assert (results_verb[i][j][0], results_verb[i][j][1].output.out) == res
1171-
# if the combiner is full expected is "a flat list"
1172-
else:
1173-
assert results.outputs.out == expected
1174-
# for i, res in enumerate(expected_val):
1175-
# assert (results_verb[i][0], results_verb[i][1].output.out) == res
1164+
assert results.outputs.out == expected
11761165

11771166

11781167
def test_task_state_comb_singl_1(plugin, tmp_path):

0 commit comments

Comments
 (0)