Skip to content

Commit bddb1f1

Browse files
committed
debugging test_workflow
1 parent cc8646b commit bddb1f1

File tree

3 files changed

+43
-34
lines changed

3 files changed

+43
-34
lines changed

pydra/engine/node.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -305,26 +305,6 @@ def _extract_input_el(self, inputs, inp_nm, ind):
305305
else:
306306
return getattr(inputs, inp_nm)[ind]
307307

308-
def _split_definition(self) -> dict[StateIndex, "TaskDef[OutputType]"]:
309-
"""Split the definition into the different states it will be run over"""
310-
# TODO: doesn't work properly for more cmplicated wf (check if still an issue)
311-
if not self.state:
312-
return {None: self._definition}
313-
split_defs = {}
314-
for input_ind in self.state.inputs_ind:
315-
inputs_dict = {}
316-
for inp in set(self.input_names):
317-
if f"{self.name}.{inp}" in input_ind:
318-
inputs_dict[inp] = self._extract_input_el(
319-
inputs=self._definition,
320-
inp_nm=inp,
321-
ind=input_ind[f"{self.name}.{inp}"],
322-
)
323-
split_defs[StateIndex(input_ind)] = attrs.evolve(
324-
self._definition, **inputs_dict
325-
)
326-
return split_defs
327-
328308
# else:
329309
# # todo it never gets here
330310
# breakpoint()

pydra/engine/submitter.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -645,12 +645,9 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
645645
name=self.node.name,
646646
)
647647
else:
648-
for index, split_defn in self.node._split_definition().items():
648+
for index, split_defn in self._split_definition().items():
649649
yield Task(
650-
definition=self._resolve_lazy_inputs(
651-
task_def=split_defn,
652-
state_index=index,
653-
),
650+
definition=split_defn,
654651
submitter=self.submitter,
655652
environment=self.node._environment,
656653
name=self.node.name,
@@ -686,6 +683,34 @@ def _resolve_lazy_inputs(
686683
)
687684
return attrs.evolve(task_def, **resolved)
688685

686+
def _split_definition(self) -> dict[StateIndex, "TaskDef[OutputType]"]:
687+
"""Split the definition into the different states it will be run over"""
688+
# TODO: doesn't work properly for more cmplicated wf (check if still an issue)
689+
if not self.node.state:
690+
return {None: self.node._definition}
691+
split_defs = {}
692+
for input_ind in self.node.state.inputs_ind:
693+
inputs_dict = {}
694+
for inp in set(self.node.input_names):
695+
if f"{self.node.name}.{inp}" in input_ind:
696+
value = getattr(self.node._definition, inp)
697+
if isinstance(value, LazyField):
698+
inputs_dict[inp] = value._get_value(
699+
workflow=self.workflow,
700+
graph=self.graph,
701+
state_index=StateIndex(input_ind),
702+
)
703+
else:
704+
inputs_dict[inp] = self.node._extract_input_el(
705+
inputs=self.node._definition,
706+
inp_nm=inp,
707+
ind=input_ind[f"{self.node.name}.{inp}"],
708+
)
709+
split_defs[StateIndex(input_ind)] = attrs.evolve(
710+
self.node._definition, **inputs_dict
711+
)
712+
return split_defs
713+
689714
def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
690715
"""For a given node, check to see which tasks have been successfully run, are ready
691716
to run, can't be run due to upstream errors, or are blocked on other tasks to complete.

pydra/engine/tests/test_workflow.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ def Workflow(x, y):
11631163
mult = workflow.add(Multiply(x=add2x.out, y=add2y.out))
11641164
return mult.out
11651165

1166-
wf = Workflow.split(["x", "y"], x=[1, 2, 3], y=[11, 12]).combine("x")
1166+
wf = Workflow().split(["x", "y"], x=[1, 2, 3], y=[11, 12]).combine("x")
11671167

11681168
with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
11691169
results = sub(wf)
@@ -1186,8 +1186,8 @@ def test_wf_3nd_ndst_2(plugin, tmpdir):
11861186

11871187
@workflow.define
11881188
def Workflow(x, y):
1189-
add2x = workflow.add(Add2().split("x", x=x))
1190-
add2y = workflow.add(Add2().split("x", x=y))
1189+
add2x = workflow.add(Add2().split("x", x=x), name="add2x")
1190+
add2y = workflow.add(Add2().split("x", x=y), name="add2y")
11911191
mult = workflow.add(Multiply(x=add2x.out, y=add2y.out).combine("add2x.x"))
11921192
return mult.out
11931193

@@ -1215,7 +1215,7 @@ def Workflow(x, y):
12151215
mult = workflow.add(Multiply(x=add2x.out, y=add2y.out))
12161216
return mult.out
12171217

1218-
wf = Workflow.split(["x", "y"], x=[1, 2, 3], y=[11, 12]).combine("y")
1218+
wf = Workflow().split(["x", "y"], x=[1, 2, 3], y=[11, 12]).combine("y")
12191219

12201220
with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
12211221
results = sub(wf)
@@ -1245,7 +1245,7 @@ def Workflow(x, y):
12451245

12461246
wf = Workflow(x=[1, 2, 3], y=[11, 12])
12471247

1248-
with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
1248+
with Submitter(worker="debug", cache_dir=tmpdir) as sub:
12491249
results = sub(wf)
12501250

12511251
assert not results.errored, "\n".join(results.errors["error message"])
@@ -1268,7 +1268,7 @@ def Workflow(x, y):
12681268
mult = workflow.add(Multiply(x=add2x.out, y=add2y.out))
12691269
return mult.out
12701270

1271-
wf = Workflow.split(["x", "y"], x=[1, 2, 3], y=[11, 12]).combine(["x", "y"])
1271+
wf = Workflow().split(["x", "y"], x=[1, 2, 3], y=[11, 12]).combine(["x", "y"])
12721272

12731273
with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
12741274
results = sub(wf)
@@ -1322,7 +1322,11 @@ def Workflow(x, y, z):
13221322
addvar = workflow.add(FunAddVar3(a=add2x.out, b=add2y.out, c=z))
13231323
return addvar.out
13241324

1325-
wf = Workflow.split(["x", "y", "z"], x=[2, 3], y=[11, 12], z=[10, 100]).combine("y")
1325+
wf = (
1326+
Workflow()
1327+
.split(["x", "y", "z"], x=[2, 3], y=[11, 12], z=[10, 100])
1328+
.combine("y")
1329+
)
13261330

13271331
with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
13281332
results = sub(wf)
@@ -1828,7 +1832,7 @@ def test_wf_ndst_singl_1(plugin, tmpdir):
18281832

18291833
@workflow.define
18301834
def Workflow(x, y):
1831-
mult = workflow.add(Multiply(y=y).split("x", x=x))
1835+
mult = workflow.add(Multiply(y=y).split("x", x=x), name="mult")
18321836
add2 = workflow.add(Add2(x=mult.out).combine("mult.x"))
18331837
return add2.out
18341838

@@ -1855,7 +1859,7 @@ def Workflow(x, y):
18551859
mult = workflow.add(Multiply(x=add2x.out, y=add2y.out))
18561860
return mult.out
18571861

1858-
wf = Workflow.split("x", x=[1, 2, 3], y=11)
1862+
wf = Workflow().split("x", x=[1, 2, 3], y=11)
18591863

18601864
with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
18611865
results = sub(wf)

0 commit comments

Comments
 (0)