Skip to content

Commit c57d830

Browse files
authored
Merge pull request #207 from djarecka/fix/wf_full_combiner
[fix] fixing workflows with first/intermediate tasks with a full combiner
2 parents a847b39 + a53be6f commit c57d830

File tree

4 files changed

+124
-46
lines changed

4 files changed

+124
-46
lines changed

pydra/engine/core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,11 @@ def _combined_output(self):
538538
if result is None:
539539
return None
540540
combined_results[gr].append(result)
541-
return combined_results
541+
if len(combined_results) == 1 and self.state.splitter_rpn_final == []:
542+
# in case it's full combiner, removing the nested structure
543+
return combined_results[0]
544+
else:
545+
return combined_results
542546

543547
def result(self, state_index=None):
544548
"""
@@ -753,7 +757,10 @@ def create_connections(self, task):
753757
self.graph.add_edges((getattr(self, val.name), task))
754758
logger.debug("Connecting %s to %s", val.name, task.name)
755759

756-
if getattr(self, val.name).state:
760+
if (
761+
getattr(self, val.name).state
762+
and getattr(self, val.name).state.splitter_rpn_final
763+
):
757764
# adding a state from the previous task to other_states
758765
other_states[val.name] = (
759766
getattr(self, val.name).state,

pydra/engine/tests/test_node_task.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,10 @@ def test_task_state_comb_1(plugin):
937937
# checking the results
938938
results = nn.result()
939939

940-
combined_results = [[res.output.out for res in res_l] for res_l in results]
941-
expected = [({}, [5, 7])]
942-
for i, res in enumerate(expected):
943-
assert combined_results[i] == res[1]
940+
# fully combined (no nested list)
941+
combined_results = [res.output.out for res in results]
942+
943+
assert combined_results == [5, 7]
944944
# checking the output_dir
945945
assert nn.output_dir
946946
for odir in nn.output_dir:
@@ -1072,11 +1072,11 @@ def test_task_state_comb_singl_1(plugin):
10721072
sub(nn)
10731073

10741074
# checking the results
1075-
expected = [({}, [13, 15])]
1075+
expected = ({}, [13, 15])
10761076
results = nn.result()
1077-
combined_results = [[res.output.out for res in res_l] for res_l in results]
1078-
for i, res in enumerate(expected):
1079-
assert combined_results[i] == res[1]
1077+
# full combiner, no nested list
1078+
combined_results = [res.output.out for res in results]
1079+
assert combined_results == expected[1]
10801080
# checking the output_dir
10811081
assert nn.output_dir
10821082
for odir in nn.output_dir:
@@ -1143,7 +1143,8 @@ def test_task_state_comb_order(plugin):
11431143
assert nn_ab.state.combiner == ["NA.a", "NA.b"]
11441144

11451145
results_ab = nn_ab()
1146-
combined_results_ab = [res.output.out for res in results_ab[0]]
1146+
# full combiner, no nested list
1147+
combined_results_ab = [res.output.out for res in results_ab]
11471148
assert combined_results_ab == [13, 15, 23, 25]
11481149

11491150
# combiner with both fields ["b", "a"] - will create the same list as nn_ab
@@ -1156,7 +1157,7 @@ def test_task_state_comb_order(plugin):
11561157
assert nn_ba.state.combiner == ["NA.b", "NA.a"]
11571158

11581159
results_ba = nn_ba()
1159-
combined_results_ba = [res.output.out for res in results_ba[0]]
1160+
combined_results_ba = [res.output.out for res in results_ba]
11601161
assert combined_results_ba == [13, 15, 23, 25]
11611162

11621163

pydra/engine/tests/test_shelltask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ def test_shell_cmd_5(plugin):
174174
assert shelly.cmdline == ["echo nipype", "echo pydra"]
175175
res = shelly(plugin=plugin)
176176

177-
assert res[0][0].output.stdout == "nipype\n"
178-
assert res[0][1].output.stdout == "pydra\n"
177+
assert res[0].output.stdout == "nipype\n"
178+
assert res[1].output.stdout == "pydra\n"
179179

180180

181181
@pytest.mark.parametrize("plugin", Plugins)

pydra/engine/tests/test_workflow.py

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,9 @@ def test_wf_st_2(plugin):
532532
sub(wf)
533533

534534
results = wf.result()
535-
# expected: [[({"test7.x": 1}, 3), ({"test7.x": 2}, 4)]]
536-
assert results[0][0].output.out == 3
537-
assert results[0][1].output.out == 4
535+
# expected: [({"test7.x": 1}, 3), ({"test7.x": 2}, 4)]
536+
assert results[0].output.out == 3
537+
assert results[1].output.out == 4
538538
# checking all directories
539539
assert wf.output_dir
540540
for odir in wf.output_dir:
@@ -554,8 +554,8 @@ def test_wf_ndst_2(plugin):
554554
sub(wf)
555555

556556
results = wf.result()
557-
# expected: [[({"test7.x": 1}, 3), ({"test7.x": 2}, 4)]]
558-
assert results.output.out[0] == [3, 4]
557+
# expected: [({"test7.x": 1}, 3), ({"test7.x": 2}, 4)]
558+
assert results.output.out == [3, 4]
559559
assert wf.output_dir.exists()
560560

561561

@@ -625,10 +625,10 @@ def test_wf_st_4(plugin):
625625

626626
results = wf.result()
627627
# expected: [
628-
# [({"test7.x": 1, "test7.y": 11}, 13), ({"test7.x": 2, "test.y": 12}, 26)]
628+
# ({"test7.x": 1, "test7.y": 11}, 13), ({"test7.x": 2, "test.y": 12}, 26)
629629
# ]
630-
assert results[0][0].output.out == 13
631-
assert results[0][1].output.out == 26
630+
assert results[0].output.out == 13
631+
assert results[1].output.out == 26
632632
# checking all directories
633633
assert wf.output_dir
634634
for odir in wf.output_dir:
@@ -652,9 +652,9 @@ def test_wf_ndst_4(plugin):
652652

653653
results = wf.result()
654654
# expected: [
655-
# [({"test7.x": 1, "test7.y": 11}, 13), ({"test7.x": 2, "test.y": 12}, 26)]
655+
# ({"test7.x": 1, "test7.y": 11}, 13), ({"test7.x": 2, "test.y": 12}, 26)
656656
# ]
657-
assert results.output.out[0] == [13, 26]
657+
assert results.output.out == [13, 26]
658658
# checking the output directory
659659
assert wf.output_dir.exists()
660660

@@ -757,11 +757,81 @@ def test_wf_ndst_6(plugin):
757757
assert wf.output_dir.exists()
758758

759759

760+
@pytest.mark.parametrize("plugin", Plugins)
761+
def test_wf_ndst_7(plugin):
762+
""" workflow with two tasks, outer splitter and (full) combiner for first node only"""
763+
wf = Workflow(name="wf_ndst_6", input_spec=["x", "y"])
764+
wf.add(multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y).split("x").combine("x"))
765+
wf.add(identity(name="iden", x=wf.mult.lzout.out))
766+
wf.inputs.x = [1, 2, 3]
767+
wf.inputs.y = 11
768+
wf.set_output([("out", wf.iden.lzout.out)])
769+
wf.plugin = plugin
770+
771+
with Submitter(plugin=plugin) as sub:
772+
sub(wf)
773+
774+
results = wf.result()
775+
assert results.output.out == [11, 22, 33]
776+
777+
# checking the output directory
778+
assert wf.output_dir.exists()
779+
780+
781+
@pytest.mark.parametrize("plugin", Plugins)
782+
def test_wf_ndst_8(plugin):
783+
""" workflow with two tasks, outer splitter and (partial) combiner for first task only"""
784+
wf = Workflow(name="wf_ndst_6", input_spec=["x", "y"])
785+
wf.add(
786+
multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y).split(["x", "y"]).combine("x")
787+
)
788+
wf.add(identity(name="iden", x=wf.mult.lzout.out))
789+
wf.inputs.x = [1, 2, 3]
790+
wf.inputs.y = [11, 12]
791+
wf.set_output([("out", wf.iden.lzout.out)])
792+
wf.plugin = plugin
793+
794+
with Submitter(plugin=plugin) as sub:
795+
sub(wf)
796+
797+
results = wf.result()
798+
assert results.output.out[0] == [11, 22, 33]
799+
assert results.output.out[1] == [12, 24, 36]
800+
801+
# checking the output directory
802+
assert wf.output_dir.exists()
803+
804+
805+
@pytest.mark.parametrize("plugin", Plugins)
806+
def test_wf_ndst_9(plugin):
807+
""" workflow with two tasks, outer splitter and (full) combiner for first task only"""
808+
wf = Workflow(name="wf_ndst_6", input_spec=["x", "y"])
809+
wf.add(
810+
multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y)
811+
.split(["x", "y"])
812+
.combine(["x", "y"])
813+
)
814+
wf.add(identity(name="iden", x=wf.mult.lzout.out))
815+
wf.inputs.x = [1, 2, 3]
816+
wf.inputs.y = [11, 12]
817+
wf.set_output([("out", wf.iden.lzout.out)])
818+
wf.plugin = plugin
819+
820+
with Submitter(plugin=plugin) as sub:
821+
sub(wf)
822+
823+
results = wf.result()
824+
assert results.output.out == [11, 12, 22, 24, 33, 36]
825+
826+
# checking the output directory
827+
assert wf.output_dir.exists()
828+
829+
760830
# workflows with structures A -> C, B -> C
761831

762832

763833
@pytest.mark.parametrize("plugin", Plugins)
764-
def test_wf_st_7(plugin):
834+
def test_wf_3nd_st_1(plugin):
765835
""" workflow with three tasks, third one connected to two previous tasks,
766836
splitter on the workflow level
767837
"""
@@ -789,7 +859,7 @@ def test_wf_st_7(plugin):
789859

790860

791861
@pytest.mark.parametrize("plugin", Plugins)
792-
def test_wf_ndst_7(plugin):
862+
def test_wf_3nd_ndst_1(plugin):
793863
""" workflow with three tasks, third one connected to two previous tasks,
794864
splitter on the tasks levels
795865
"""
@@ -813,7 +883,7 @@ def test_wf_ndst_7(plugin):
813883

814884

815885
@pytest.mark.parametrize("plugin", Plugins)
816-
def test_wf_st_8(plugin):
886+
def test_wf_3nd_st_2(plugin):
817887
""" workflow with three tasks, third one connected to two previous tasks,
818888
splitter and partial combiner on the workflow level
819889
"""
@@ -844,7 +914,7 @@ def test_wf_st_8(plugin):
844914

845915

846916
@pytest.mark.parametrize("plugin", Plugins)
847-
def test_wf_ndst_8(plugin):
917+
def test_wf_3nd_ndst_2(plugin):
848918
""" workflow with three tasks, third one connected to two previous tasks,
849919
splitter and partial combiner on the tasks levels
850920
"""
@@ -873,7 +943,7 @@ def test_wf_ndst_8(plugin):
873943

874944

875945
@pytest.mark.parametrize("plugin", Plugins)
876-
def test_wf_st_9(plugin):
946+
def test_wf_3nd_st_3(plugin):
877947
""" workflow with three tasks, third one connected to two previous tasks,
878948
splitter and partial combiner (from the second task) on the workflow level
879949
"""
@@ -904,7 +974,7 @@ def test_wf_st_9(plugin):
904974

905975

906976
@pytest.mark.parametrize("plugin", Plugins)
907-
def test_wf_ndst_9(plugin):
977+
def test_wf_3nd_ndst_3(plugin):
908978
""" workflow with three tasks, third one connected to two previous tasks,
909979
splitter and partial combiner (from the second task) on the tasks levels
910980
"""
@@ -934,7 +1004,7 @@ def test_wf_ndst_9(plugin):
9341004

9351005

9361006
@pytest.mark.parametrize("plugin", Plugins)
937-
def test_wf_st_10(plugin):
1007+
def test_wf_3nd_st_4(plugin):
9381008
""" workflow with three tasks, third one connected to two previous tasks,
9391009
splitter and full combiner on the workflow level
9401010
"""
@@ -950,21 +1020,21 @@ def test_wf_st_10(plugin):
9501020
sub(wf)
9511021

9521022
results = wf.result()
953-
assert len(results) == 1
954-
assert results[0][0].output.out == 39
955-
assert results[0][1].output.out == 42
956-
assert results[0][2].output.out == 52
957-
assert results[0][3].output.out == 56
958-
assert results[0][4].output.out == 65
959-
assert results[0][5].output.out == 70
1023+
assert len(results) == 6
1024+
assert results[0].output.out == 39
1025+
assert results[1].output.out == 42
1026+
assert results[2].output.out == 52
1027+
assert results[3].output.out == 56
1028+
assert results[4].output.out == 65
1029+
assert results[5].output.out == 70
9601030
# checking all directories
9611031
assert wf.output_dir
9621032
for odir in wf.output_dir:
9631033
assert odir.exists()
9641034

9651035

9661036
@pytest.mark.parametrize("plugin", Plugins)
967-
def test_wf_ndst_10(plugin):
1037+
def test_wf_3nd_ndst_4(plugin):
9681038
""" workflow with three tasks, third one connected to two previous tasks,
9691039
splitter and full combiner on the tasks levels
9701040
"""
@@ -986,8 +1056,8 @@ def test_wf_ndst_10(plugin):
9861056
# assert wf.output_dir.exists()
9871057
results = wf.result()
9881058

989-
assert len(results.output.out) == 1
990-
assert results.output.out == [[39, 42, 52, 56, 65, 70]]
1059+
assert len(results.output.out) == 6
1060+
assert results.output.out == [39, 42, 52, 56, 65, 70]
9911061
# checking the output directory
9921062
assert wf.output_dir.exists()
9931063

@@ -1284,8 +1354,8 @@ def test_wf_st_singl_1(plugin):
12841354
sub(wf)
12851355

12861356
results = wf.result()
1287-
assert results[0][0].output.out == 13
1288-
assert results[0][1].output.out == 24
1357+
assert results[0].output.out == 13
1358+
assert results[1].output.out == 24
12891359
# checking all directories
12901360
assert wf.output_dir
12911361
for odir in wf.output_dir:
@@ -1309,7 +1379,7 @@ def test_wf_ndst_singl_1(plugin):
13091379
sub(wf)
13101380

13111381
results = wf.result()
1312-
assert results.output.out[0] == [13, 24]
1382+
assert results.output.out == [13, 24]
13131383
# checking the output directory
13141384
assert wf.output_dir.exists()
13151385

@@ -2665,7 +2735,7 @@ def test_workflow_combine1(tmpdir):
26652735

26662736
assert result.output.out_pow == [1, 1, 4, 8]
26672737
assert result.output.out_iden1 == [[1, 4], [1, 8]]
2668-
assert result.output.out_iden2 == [[[1, 4], [1, 8]]]
2738+
assert result.output.out_iden2 == [[1, 4], [1, 8]]
26692739

26702740

26712741
def test_workflow_combine2(tmpdir):
@@ -2679,7 +2749,7 @@ def test_workflow_combine2(tmpdir):
26792749
result = wf1(plugin="cf")
26802750

26812751
assert result.output.out_pow == [[1, 4], [1, 8]]
2682-
assert result.output.out_iden == [[[1, 4], [1, 8]]]
2752+
assert result.output.out_iden == [[1, 4], [1, 8]]
26832753

26842754

26852755
# testing lzout.all to collect all of the results and let FunctionTask deal with it

0 commit comments

Comments
 (0)