
Commit b03307d

Backed out of trying to allow execution of downstream nodes while upstream nodes are partially incomplete or have errored.

1 parent f51063e

4 files changed: 42 additions, 35 deletions


pydra/engine/core.py

Lines changed: 2 additions & 1 deletion
@@ -844,7 +844,8 @@ def _create_graph(
                 # treated as a containers
                 if (
                     node.state
-                    and f"{node.name}.{field.name}" in node.state.splitter
+                    and f"{node.name}.{field.name}"
+                    in node.state._current_splitter_rpn
                 ):
                     node.state._inner_cont_dim[
                         f"{node.name}.{field.name}"

pydra/engine/node.py

Lines changed: 4 additions & 4 deletions
@@ -225,15 +225,15 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
         ):
             node: Node = val._node
             # variables that are part of inner splitters should be treated as a containers
-            if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:
-                node.state._inner_cont_dim[f"{node.name}.{inpt_name}"] = 1
+            if node.state and f"{node.name}.{val._field}" in node.state.splitter:
+                node.state._inner_cont_dim[f"{node.name}.{val._field}"] = 1
             # adding task_name: (task.state, [a field from the connection]
             if node.name not in upstream_states:
-                upstream_states[node.name] = (node.state, [inpt_name])
+                upstream_states[node.name] = (node.state, [val._field])
             else:
                 # if the task already exist in other_state,
                 # additional field name should be added to the list of fields
-                upstream_states[node.name][1].append(inpt_name)
+                upstream_states[node.name][1].append(val._field)
         return upstream_states
 
         # else:
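This hunk keys the upstream-state bookkeeping on the field recorded on the lazy connection (val._field) rather than on the downstream input name (inpt_name), which matters whenever the two differ. A rough, hypothetical sketch of that distinction (the names and classes here are illustrative, not pydra's API):

# Illustrative only: a connection records which upstream node and which of its
# fields it refers to; the downstream input name may differ from that field.
from dataclasses import dataclass

@dataclass
class Connection:
    node_name: str   # upstream node
    field: str       # field on the upstream node

def upstream_fields(connections):
    """Group the referenced upstream fields by upstream node name."""
    grouped: dict[str, list[str]] = {}
    for inpt_name, conn in connections.items():
        # key on conn.field (the upstream field), not on inpt_name
        grouped.setdefault(conn.node_name, []).append(conn.field)
    return grouped

conns = {"in_file": Connection("preprocess", "out_file")}
print(upstream_fields(conns))  # {'preprocess': ['out_file']}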

pydra/engine/state.py

Lines changed: 3 additions & 3 deletions
@@ -885,7 +885,7 @@ def prepare_states(
         st: State
         for nm, (st, _) in self.other_states.items():
             self.inputs.update(st.inputs)
-            self.cont_dim.update(st.cont_dim)
+            self.cont_dim.update(st.cont_dim_all)
 
         self.prepare_states_ind()
         self.prepare_states_val()
@@ -1225,11 +1225,11 @@ def _get_element(self, value: ty.Any, field_name: str, ind: int) -> ty.Any:
         Any
             specific element of the input field
         """
-        if f"{self.name}.{field_name}" in self.cont_dim:
+        if f"{self.name}.{field_name}" in self.cont_dim_all:
             return list(
                 hlpst.flatten(
                     ensure_list(value),
-                    max_depth=self.cont_dim[f"{self.name}.{field_name}"],
+                    max_depth=self.cont_dim_all[f"{self.name}.{field_name}"],
                 )
             )[ind]
         else:
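These hunks read container depths from cont_dim_all rather than cont_dim, presumably so that container dimensions contributed by other (upstream) states are also visible when a single element is picked out of a flattened value. A self-contained sketch of the flatten-then-index idea (not pydra's code):

# Sketch only: a value with a known container depth is flattened to that depth
# before a single state index is used to pick out one element.
def flatten(value, max_depth, depth=0):
    """Yield elements of nested lists down to max_depth levels."""
    for item in value:
        if isinstance(item, list) and depth + 1 < max_depth:
            yield from flatten(item, max_depth, depth + 1)
        else:
            yield item

def get_element(value, cont_dim, ind):
    """Pick the ind-th element after flattening value cont_dim levels deep."""
    return list(flatten(value, max_depth=cont_dim))[ind]

# A container dimension of 2 flattens one level of nesting before indexing.
print(get_element([[1, 2], [3, 4]], cont_dim=2, ind=2))  # 3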

pydra/engine/submitter.py

Lines changed: 33 additions & 27 deletions
@@ -570,7 +570,9 @@ def start(self) -> None:
         if self.state:
             values = {}
             for name, value in self.node.state_values.items():
-                if name in self.node.state.names:
+                if name in self.node.state.current_splitter_rpn:
+                    if name in ("*", "."):
+                        continue
                     if isinstance(value, LazyField):
                         values[name] = value._get_value(
                             workflow=self.workflow, graph=self.graph
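In start(), the values loop now filters against the RPN form of the current splitter, which mixes field names with the "*" and "." operator tokens, hence the added guard that skips those tokens. A small illustration (tokens and values below are made up):

# Sketch only: an RPN splitter representation interleaves field names with the
# "*" and "." operator tokens, so a loop resolving values by name skips them.
splitter_rpn = ["wf.a", "wf.b", "*"]
state_values = {"wf.a": [1, 2], "wf.b": ["x", "y"]}

resolved = {
    name: value
    for name, value in state_values.items()
    if name in splitter_rpn and name not in ("*", ".")
}
print(resolved)  # {'wf.a': [1, 2], 'wf.b': ['x', 'y']}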
@@ -692,23 +694,20 @@ def _split_definition(self) -> dict[int, "TaskDef[OutputType]"]:
         if not self.node.state:
             return {None: self.node._definition}
         split_defs = []
-        for input_ind in self.node.state.inputs_ind:
+        for index, vals in zip(self.node.state.inputs_ind, self.node.state.states_val):
             resolved = {}
             for inpt_name in set(self.node.input_names):
                 value = getattr(self._definition, inpt_name)
                 state_key = f"{self.node.name}.{inpt_name}"
-                if isinstance(value, LazyField):
-                    resolved[inpt_name] = value._get_value(
-                        workflow=self.workflow,
-                        graph=self.graph,
-                        state_index=input_ind.get(state_key),
-                    )
-                elif state_key in input_ind:
-                    resolved[inpt_name] = self.node.state._get_element(
-                        value=value,
-                        field_name=inpt_name,
-                        ind=input_ind[state_key],
-                    )
+                try:
+                    resolved[inpt_name] = vals[state_key]
+                except KeyError:
+                    if isinstance(value, LazyField):
+                        resolved[inpt_name] = value._get_value(
+                            workflow=self.workflow,
+                            graph=self.graph,
+                            state_index=index.get(state_key),
+                        )
             split_defs.append(attrs.evolve(self.node._definition, **resolved))
         return split_defs
 
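_split_definition() now iterates the state indices together with the precomputed per-state values and only falls back to resolving lazy fields when a key is absent. A hedged sketch of that lookup-then-fallback pattern (standalone, not pydra's code):

# Sketch only: prefer a precomputed per-state value, fall back to a lazy
# resolver when the key is missing.
def resolve_input(state_key, state_vals, lazy_resolver):
    try:
        return state_vals[state_key]      # value already expanded for this state
    except KeyError:
        return lazy_resolver(state_key)   # fall back to lazy resolution

vals = {"task.x": 3}
print(resolve_input("task.x", vals, lambda k: None))    # 3
print(resolve_input("task.y", vals, lambda k: "lazy"))  # 'lazy'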

@@ -736,21 +735,28 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
         for index, task in list(self.blocked.items()):
             pred: NodeExecution
             is_runnable = True
-            states_ind = (
-                list(self.node.state.states_ind[index].items())
-                if self.node.state
-                else []
-            )
+            # This is required for the commented-out code below
+            # states_ind = (
+            #     list(self.node.state.states_ind[index].items())
+            #     if self.node.state
+            #     else []
+            # )
             for pred in graph.predecessors[self.node.name]:
                 if pred.node.state:
-                    pred_states_ind = {
-                        (k, i) for k, i in states_ind if k.startswith(pred.name + ".")
-                    }
-                    pred_inds = [
-                        i
-                        for i, ind in enumerate(pred.node.state.states_ind)
-                        if set(ind.items()).issuperset(pred_states_ind)
-                    ]
+                    # FIXME: These should be the only predecessor jobs that are required to have
+                    # completed before the job can be run, however, due to how the state
+                    # is currently built, all predecessors are required to have completed.
+                    # If/when this is relaxed, then the following code should be used instead.
+                    #
+                    # pred_states_ind = {
+                    #     (k, i) for k, i in states_ind if k.startswith(pred.name + ".")
+                    # }
+                    # pred_inds = [
+                    #     i
+                    #     for i, ind in enumerate(pred.node.state.states_ind)
+                    #     if set(ind.items()).issuperset(pred_states_ind)
+                    # ]
+                    pred_inds = list(range(len(pred.node.state.states_ind)))
                 else:
                     pred_inds = [None]
                 if not all(i in pred.successful for i in pred_inds):
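Consistent with the commit message, the index-matched predecessor check is parked in comments and every predecessor state index must now have completed before a blocked job is released. A simplified sketch of this conservative gate (the shapes here are assumed, not pydra's API):

# Sketch only: a blocked job becomes runnable once every state index of a
# predecessor has completed successfully.
def is_runnable(pred_state_count: int, successful: set[int]) -> bool:
    pred_inds = list(range(pred_state_count))   # all predecessor indices
    return all(i in successful for i in pred_inds)

print(is_runnable(3, {0, 1, 2}))  # True: every predecessor job finished
print(is_runnable(3, {0, 2}))     # False: index 1 still pending or errored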
