Skip to content

Commit 0548de1

Browse files
committed
adding option to run set_output twice in order to update connections and wf outputs; adding tests
1 parent adb9c27 commit 0548de1

File tree

3 files changed

+160
-3
lines changed

3 files changed

+160
-3
lines changed

pydra/engine/core.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -927,18 +927,29 @@ def set_output(self, connections):
927927
TODO
928928
929929
"""
930+
if self._connections is None:
931+
self._connections = []
930932
if isinstance(connections, tuple) and len(connections) == 2:
931-
self._connections = [connections]
933+
new_connections = [connections]
932934
elif isinstance(connections, list) and all(
933935
[len(el) == 2 for el in connections]
934936
):
935-
self._connections = connections
937+
new_connections = connections
936938
elif isinstance(connections, dict):
937-
self._connections = list(connections.items())
939+
new_connections = list(connections.items())
938940
else:
939941
raise Exception(
940942
"Connections can be a 2-elements tuple, a list of these tuples, or dictionary"
941943
)
944+
# checking if a new output name is already in the connections
945+
connection_names = [name for name, _ in self._connections]
946+
new_names = [name for name, _ in new_connections]
947+
if set(connection_names).intersection(new_names):
948+
raise Exception(
949+
f"output name {set(connection_names).intersection(new_names)} is already set"
950+
)
951+
952+
self._connections += new_connections
942953
fields = [(name, ty.Any) for name, _ in self._connections]
943954
self.output_spec = SpecInfo(name="Output", fields=fields, bases=(BaseSpec,))
944955
logger.info("Added %s to %s", self.output_spec, self)

pydra/engine/tests/test_workflow.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ten,
1313
identity,
1414
list_output,
15+
fun_addsubvar,
1516
fun_addvar3,
1617
add2_sub2_res,
1718
fun_addvar_none,
@@ -291,6 +292,87 @@ def test_wf_4a(plugin):
291292
assert 5 == results.output.out
292293

293294

295+
def test_wf_5(plugin):
296+
""" wf with two outputs connected to the task outputs
297+
one set_output
298+
"""
299+
wf = Workflow(name="wf_5", input_spec=["x", "y"], x=3, y=2)
300+
wf.add(fun_addsubvar(name="addsub", a=wf.lzin.x, b=wf.lzin.y))
301+
wf.set_output([("out_sum", wf.addsub.lzout.sum), ("out_sub", wf.addsub.lzout.sub)])
302+
303+
with Submitter(plugin=plugin) as sub:
304+
sub(wf)
305+
306+
results = wf.result()
307+
assert 5 == results.output.out_sum
308+
assert 1 == results.output.out_sub
309+
310+
311+
def test_wf_5a(plugin):
312+
""" wf with two outputs connected to the task outputs,
313+
set_output set twice
314+
"""
315+
wf = Workflow(name="wf_5", input_spec=["x", "y"], x=3, y=2)
316+
wf.add(fun_addsubvar(name="addsub", a=wf.lzin.x, b=wf.lzin.y))
317+
wf.set_output([("out_sum", wf.addsub.lzout.sum)])
318+
wf.set_output([("out_sub", wf.addsub.lzout.sub)])
319+
320+
with Submitter(plugin=plugin) as sub:
321+
sub(wf)
322+
323+
results = wf.result()
324+
assert 5 == results.output.out_sum
325+
assert 1 == results.output.out_sub
326+
327+
328+
def test_wf_5b_exception():
329+
""" set_output used twice with the same name - exception should be raised """
330+
wf = Workflow(name="wf_5", input_spec=["x", "y"], x=3, y=2)
331+
wf.add(fun_addsubvar(name="addsub", a=wf.lzin.x, b=wf.lzin.y))
332+
wf.set_output([("out", wf.addsub.lzout.sum)])
333+
334+
with pytest.raises(Exception) as excinfo:
335+
wf.set_output([("out", wf.addsub.lzout.sub)])
336+
assert "is already set" in str(excinfo.value)
337+
338+
339+
def test_wf_6(plugin):
340+
""" wf with two tasks and two outputs connected to both tasks,
341+
one set_output
342+
"""
343+
wf = Workflow(name="wf_6", input_spec=["x", "y"], x=2, y=3)
344+
wf.add(multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y))
345+
wf.add(add2(name="add2", x=wf.mult.lzout.out))
346+
wf.set_output([("out1", wf.mult.lzout.out), ("out2", wf.add2.lzout.out)])
347+
348+
with Submitter(plugin=plugin) as sub:
349+
sub(wf)
350+
351+
assert wf.output_dir.exists()
352+
results = wf.result()
353+
assert 6 == results.output.out1
354+
assert 8 == results.output.out2
355+
356+
357+
def test_wf_6a(plugin):
358+
""" wf with two tasks and two outputs connected to both tasks,
359+
set_output used twice
360+
"""
361+
wf = Workflow(name="wf_6", input_spec=["x", "y"], x=2, y=3)
362+
wf.add(multiply(name="mult", x=wf.lzin.x, y=wf.lzin.y))
363+
wf.add(add2(name="add2", x=wf.mult.lzout.out))
364+
wf.set_output([("out1", wf.mult.lzout.out)])
365+
wf.set_output([("out2", wf.add2.lzout.out)])
366+
367+
with Submitter(plugin=plugin) as sub:
368+
sub(wf)
369+
370+
assert wf.output_dir.exists()
371+
results = wf.result()
372+
assert 6 == results.output.out1
373+
assert 8 == results.output.out2
374+
375+
294376
def test_wf_st_1(plugin):
295377
""" Workflow with one task, a splitter for the workflow"""
296378
wf = Workflow(name="wf_spl_1", input_spec=["x"])
@@ -2055,6 +2137,64 @@ def test_wf_nostate_cachelocations_a(plugin, tmpdir):
20552137
assert wf2.output_dir.exists()
20562138

20572139

2140+
def test_wf_nostate_cachelocations_b(plugin, tmpdir):
2141+
"""
2142+
the same as previous test, but the 2nd workflows has two outputs
2143+
(connected to the same task output);
2144+
the task should not be run and it should be fast,
2145+
but the wf itself is triggered and the new output dir is created
2146+
"""
2147+
cache_dir1 = tmpdir.mkdir("test_wf_cache3")
2148+
cache_dir2 = tmpdir.mkdir("test_wf_cache4")
2149+
2150+
wf1 = Workflow(name="wf", input_spec=["x", "y"], cache_dir=cache_dir1)
2151+
wf1.add(multiply(name="mult", x=wf1.lzin.x, y=wf1.lzin.y))
2152+
wf1.add(add2_wait(name="add2", x=wf1.mult.lzout.out))
2153+
wf1.set_output([("out", wf1.add2.lzout.out)])
2154+
wf1.inputs.x = 2
2155+
wf1.inputs.y = 3
2156+
wf1.plugin = plugin
2157+
2158+
t0 = time.time()
2159+
with Submitter(plugin=plugin) as sub:
2160+
sub(wf1)
2161+
t1 = time.time() - t0
2162+
2163+
results1 = wf1.result()
2164+
assert 8 == results1.output.out
2165+
2166+
wf2 = Workflow(
2167+
name="wf",
2168+
input_spec=["x", "y"],
2169+
cache_dir=cache_dir2,
2170+
cache_locations=cache_dir1,
2171+
)
2172+
wf2.add(multiply(name="mult", x=wf2.lzin.x, y=wf2.lzin.y))
2173+
wf2.add(add2_wait(name="add2", x=wf2.mult.lzout.out))
2174+
wf2.set_output([("out", wf2.add2.lzout.out)])
2175+
# additional output
2176+
wf2.set_output([("out_pr", wf2.add2.lzout.out)])
2177+
wf2.inputs.x = 2
2178+
wf2.inputs.y = 3
2179+
wf2.plugin = plugin
2180+
2181+
t0 = time.time()
2182+
with Submitter(plugin=plugin) as sub:
2183+
sub(wf2)
2184+
t2 = time.time() - t0
2185+
2186+
results2 = wf2.result()
2187+
assert 8 == results2.output.out == results2.output.out_pr
2188+
2189+
# checking execution time
2190+
assert t1 > 3
2191+
assert t2 / t1 < 0.9
2192+
2193+
# checking if the second wf didn't run again
2194+
assert wf1.output_dir.exists()
2195+
assert wf2.output_dir.exists()
2196+
2197+
20582198
@pytest.mark.flaky(reruns=2) # windows test sometimes gives longer time t2
20592199
def test_wf_nostate_cachelocations_setoutputchange(plugin, tmpdir):
20602200
"""

pydra/engine/tests/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def fun_addvar(a, b):
5151
return a + b
5252

5353

54+
@mark.task
55+
@mark.annotate({"return": {"sum": float, "sub": float}})
56+
def fun_addsubvar(a, b):
57+
return a + b, a - b
58+
59+
5460
@mark.task
5561
def fun_addvar_none(a, b):
5662
if b is None:

0 commit comments

Comments
 (0)