Skip to content

Commit eee43b7

Browse files
committed
fixing workflows with first/intermediate task that has a full combiner; adding tests
1 parent a847b39 commit eee43b7

File tree

3 files changed

+90
-16
lines changed

3 files changed

+90
-16
lines changed

pydra/engine/core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def checksum(self):
222222
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
223223

224224
input_hash = self.inputs.hash
225-
if self.state is None:
225+
if self.state is None or self.state.splitter_rpn == []:
226226
self._checksum = create_checksum(self.__class__.__name__, input_hash)
227227
else:
228228
# including splitter in the hash
@@ -329,7 +329,7 @@ def cache_locations(self, locations):
329329
@property
330330
def output_dir(self):
331331
"""Get the filesystem path where outputs will be written."""
332-
if self.state:
332+
if self.state and self.state.splitter_rpn:
333333
return [self._cache_dir / checksum for checksum in self.checksum_states()]
334334
return self._cache_dir / self.checksum
335335

@@ -342,7 +342,7 @@ def __call__(self, submitter=None, plugin=None, rerun=False, **kwargs):
342342
plugin = plugin or self.plugin
343343
if plugin:
344344
submitter = Submitter(plugin=plugin)
345-
elif self.state:
345+
elif self.state and self.state.splitter_rpn:
346346
submitter = Submitter()
347347

348348
if submitter:
@@ -512,7 +512,7 @@ def done(self):
512512
# if any of the field is lazy, there is no need to check results
513513
if is_lazy(self.inputs):
514514
return False
515-
if self.state:
515+
if self.state and self.state.splitter_rpn:
516516
# TODO: only check for needed state result
517517
if self.result() and all(self.result()):
518518
return True
@@ -556,7 +556,7 @@ def result(self, state_index=None):
556556
"""
557557
# TODO: check if result is available in load_result and
558558
# return a future if not
559-
if self.state:
559+
if self.state and self.state.splitter_rpn:
560560
if state_index is None:
561561
# if state_index=None, collecting all results
562562
if self.state.combiner:

pydra/engine/submitter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def __call__(self, runnable, cache_locations=None, rerun=False):
5050
runnable.inputs._graph_checksums = [
5151
nd.checksum for nd in runnable.graph_sorted
5252
]
53-
if is_workflow(runnable) and runnable.state is None:
53+
if is_workflow(runnable) and (
54+
runnable.state is None or runnable.state.splitter_rpn == []
55+
):
5456
self.loop.run_until_complete(self.submit_workflow(runnable, rerun=rerun))
5557
else:
5658
self.loop.run_until_complete(self.submit(runnable, wait=True, rerun=rerun))
@@ -91,7 +93,7 @@ async def submit(self, runnable, wait=False, rerun=False):
9193
9294
"""
9395
futures = set()
94-
if runnable.state:
96+
if runnable.state and runnable.state.splitter_rpn:
9597
runnable.state.prepare_states(runnable.inputs)
9698
runnable.state.prepare_inputs()
9799
logger.debug(
@@ -154,7 +156,9 @@ async def _run_workflow(self, wf, rerun=False):
154156
task.inputs.retrieve_values(wf)
155157
# checksum has to be updated, so resetting
156158
task._checksum = None
157-
if is_workflow(task) and not task.state:
159+
if is_workflow(task) and (
160+
not task.state or task.state.splitter_rpn == []
161+
):
158162
await self.submit_workflow(task, rerun=rerun)
159163
else:
160164
for fut in await self.submit(task, rerun=rerun):

pydra/engine/tests/test_workflow.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -757,11 +757,81 @@ def test_wf_ndst_6(plugin):
757757
assert wf.output_dir.exists()
758758

759759

760+
@pytest.mark.parametrize("plugin", Plugins)
761+
def test_wf_ndst_7(plugin):
762+
""" workflow with two tasks, outer splitter and (full) combiner for first node only"""
763+
wf = Workflow(name="wf_ndst_6", input_spec=["x", "y"])
764+
wf.add(multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y).split("x").combine("x"))
765+
wf.add(identity(name="iden", x=wf.mult.lzout.out))
766+
wf.inputs.x = [1, 2, 3]
767+
wf.inputs.y = 11
768+
wf.set_output([("out", wf.iden.lzout.out)])
769+
wf.plugin = plugin
770+
771+
with Submitter(plugin=plugin) as sub:
772+
sub(wf)
773+
774+
results = wf.result()
775+
assert results.output.out[0] == [11, 22, 33]
776+
777+
# checking the output directory
778+
assert wf.output_dir.exists()
779+
780+
781+
@pytest.mark.parametrize("plugin", Plugins)
782+
def test_wf_ndst_8(plugin):
783+
""" workflow with two tasks, outer splitter and (partial) combiner for first task only"""
784+
wf = Workflow(name="wf_ndst_6", input_spec=["x", "y"])
785+
wf.add(
786+
multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y).split(["x", "y"]).combine("x")
787+
)
788+
wf.add(identity(name="iden", x=wf.mult.lzout.out))
789+
wf.inputs.x = [1, 2, 3]
790+
wf.inputs.y = [11, 12]
791+
wf.set_output([("out", wf.iden.lzout.out)])
792+
wf.plugin = plugin
793+
794+
with Submitter(plugin=plugin) as sub:
795+
sub(wf)
796+
797+
results = wf.result()
798+
assert results.output.out[0] == [11, 22, 33]
799+
assert results.output.out[1] == [12, 24, 36]
800+
801+
# checking the output directory
802+
assert wf.output_dir.exists()
803+
804+
805+
@pytest.mark.parametrize("plugin", Plugins)
806+
def test_wf_ndst_9(plugin):
807+
""" workflow with two tasks, outer splitter and (full) combiner for first task only"""
808+
wf = Workflow(name="wf_ndst_6", input_spec=["x", "y"])
809+
wf.add(
810+
multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y)
811+
.split(["x", "y"])
812+
.combine(["x", "y"])
813+
)
814+
wf.add(identity(name="iden", x=wf.mult.lzout.out))
815+
wf.inputs.x = [1, 2, 3]
816+
wf.inputs.y = [11, 12]
817+
wf.set_output([("out", wf.iden.lzout.out)])
818+
wf.plugin = plugin
819+
820+
with Submitter(plugin=plugin) as sub:
821+
sub(wf)
822+
823+
results = wf.result()
824+
assert results.output.out[0] == [11, 12, 22, 24, 33, 36]
825+
826+
# checking the output directory
827+
assert wf.output_dir.exists()
828+
829+
760830
# workflows with structures A -> C, B -> C
761831

762832

763833
@pytest.mark.parametrize("plugin", Plugins)
764-
def test_wf_st_7(plugin):
834+
def test_wf_3nd_st_1(plugin):
765835
""" workflow with three tasks, third one connected to two previous tasks,
766836
splitter on the workflow level
767837
"""
@@ -789,7 +859,7 @@ def test_wf_st_7(plugin):
789859

790860

791861
@pytest.mark.parametrize("plugin", Plugins)
792-
def test_wf_ndst_7(plugin):
862+
def test_wf_3nd_ndst_1(plugin):
793863
""" workflow with three tasks, third one connected to two previous tasks,
794864
splitter on the tasks levels
795865
"""
@@ -813,7 +883,7 @@ def test_wf_ndst_7(plugin):
813883

814884

815885
@pytest.mark.parametrize("plugin", Plugins)
816-
def test_wf_st_8(plugin):
886+
def test_wf_3nd_st_2(plugin):
817887
""" workflow with three tasks, third one connected to two previous tasks,
818888
splitter and partial combiner on the workflow level
819889
"""
@@ -844,7 +914,7 @@ def test_wf_st_8(plugin):
844914

845915

846916
@pytest.mark.parametrize("plugin", Plugins)
847-
def test_wf_ndst_8(plugin):
917+
def test_wf_3nd_ndst_2(plugin):
848918
""" workflow with three tasks, third one connected to two previous tasks,
849919
splitter and partial combiner on the tasks levels
850920
"""
@@ -873,7 +943,7 @@ def test_wf_ndst_8(plugin):
873943

874944

875945
@pytest.mark.parametrize("plugin", Plugins)
876-
def test_wf_st_9(plugin):
946+
def test_wf_3nd_st_3(plugin):
877947
""" workflow with three tasks, third one connected to two previous tasks,
878948
splitter and partial combiner (from the second task) on the workflow level
879949
"""
@@ -904,7 +974,7 @@ def test_wf_st_9(plugin):
904974

905975

906976
@pytest.mark.parametrize("plugin", Plugins)
907-
def test_wf_ndst_9(plugin):
977+
def test_wf_3nd_ndst_3(plugin):
908978
""" workflow with three tasks, third one connected to two previous tasks,
909979
splitter and partial combiner (from the second task) on the tasks levels
910980
"""
@@ -934,7 +1004,7 @@ def test_wf_ndst_9(plugin):
9341004

9351005

9361006
@pytest.mark.parametrize("plugin", Plugins)
937-
def test_wf_st_10(plugin):
1007+
def test_wf_3nd_st_4(plugin):
9381008
""" workflow with three tasks, third one connected to two previous tasks,
9391009
splitter and full combiner on the workflow level
9401010
"""
@@ -964,7 +1034,7 @@ def test_wf_st_10(plugin):
9641034

9651035

9661036
@pytest.mark.parametrize("plugin", Plugins)
967-
def test_wf_ndst_10(plugin):
1037+
def test_wf_3nd_ndst_4(plugin):
9681038
""" workflow with three tasks, third one connected to two previous tasks,
9691039
splitter and full combiner on the tasks levels
9701040
"""

0 commit comments

Comments
 (0)