Skip to content

Commit e88db6f

Browse files
committed
Add task_rerun to the task attributes; for a workflow it does not propagate to the nodes, but it forces the workflow itself to rerun
1 parent 632fb9c commit e88db6f

File tree

3 files changed

+138
-3
lines changed

3 files changed

+138
-3
lines changed

pydra/engine/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
inputs: ty.Union[ty.Text, File, ty.Dict, None] = None,
8585
messenger_args=None,
8686
messengers=None,
87+
rerun=False,
8788
):
8889
"""
8990
Initialize a task.
@@ -173,6 +174,8 @@ def __init__(
173174
self.cache_locations = cache_locations
174175
self.allow_cache_override = True
175176
self._checksum = None
177+
# if True, cached results are not checked before running (does not propagate to nodes)
178+
self.task_rerun = rerun
176179

177180
self.plugin = None
178181
self.hooks = TaskHook()
@@ -365,7 +368,7 @@ def _run(self, rerun=False, **kwargs):
365368
self.hooks.pre_run(self)
366369
# TODO add signal handler for processes killed after lock acquisition
367370
with SoftFileLock(lockfile):
368-
if not rerun:
371+
if not (rerun or self.task_rerun):
369372
result = self.result()
370373
if result is not None:
371374
return result
@@ -610,6 +613,7 @@ def __init__(
610613
messenger_args=None,
611614
messengers=None,
612615
output_spec: ty.Optional[BaseSpec] = None,
616+
rerun=False,
613617
**kwargs,
614618
):
615619
"""
@@ -671,6 +675,7 @@ def __init__(
671675
audit_flags=audit_flags,
672676
messengers=messengers,
673677
messenger_args=messenger_args,
678+
rerun=rerun,
674679
)
675680

676681
self.graph = DiGraph()
@@ -789,7 +794,7 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
789794
checksum = self.checksum
790795
lockfile = self.cache_dir / (checksum + ".lock")
791796
# Eagerly retrieve cached
792-
if not rerun:
797+
if not (rerun or self.task_rerun):
793798
result = self.result()
794799
if result is not None:
795800
return result

pydra/engine/task.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
messengers=None,
7676
name=None,
7777
output_spec: ty.Optional[BaseSpec] = None,
78+
rerun=False,
7879
**kwargs,
7980
):
8081
"""
@@ -143,6 +144,7 @@ def __init__(
143144
messenger_args=messenger_args,
144145
cache_dir=cache_dir,
145146
cache_locations=cache_locations,
147+
rerun=rerun,
146148
)
147149
if output_spec is None:
148150
if "return" not in func.__annotations__:
@@ -242,6 +244,7 @@ def __init__(
242244
messengers=None,
243245
name=None,
244246
output_spec: ty.Optional[SpecInfo] = None,
247+
rerun=False,
245248
strip=False,
246249
**kwargs,
247250
):
@@ -283,6 +286,7 @@ def __init__(
283286
messengers=messengers,
284287
messenger_args=messenger_args,
285288
cache_dir=cache_dir,
289+
rerun=rerun,
286290
)
287291
self.strip = strip
288292

@@ -412,6 +416,7 @@ def __init__(
412416
messengers=None,
413417
output_cpath="/output_pydra",
414418
output_spec: ty.Optional[SpecInfo] = None,
419+
rerun=False,
415420
strip=False,
416421
**kwargs,
417422
):
@@ -452,6 +457,7 @@ def __init__(
452457
messenger_args=messenger_args,
453458
cache_dir=cache_dir,
454459
strip=strip,
460+
rerun=rerun,
455461
**kwargs,
456462
)
457463

@@ -523,6 +529,7 @@ def __init__(
523529
messengers=None,
524530
output_cpath="/output_pydra",
525531
output_spec: ty.Optional[SpecInfo] = None,
532+
rerun=False,
526533
strip=False,
527534
**kwargs,
528535
):
@@ -565,6 +572,7 @@ def __init__(
565572
cache_dir=cache_dir,
566573
strip=strip,
567574
output_cpath=output_cpath,
575+
rerun=rerun,
568576
**kwargs,
569577
)
570578
self.inputs.container_xargs = ["--rm"]
@@ -619,6 +627,7 @@ def __init__(
619627
messenger_args=None,
620628
messengers=None,
621629
output_spec: ty.Optional[SpecInfo] = None,
630+
rerun=False,
622631
strip=False,
623632
**kwargs,
624633
):
@@ -658,6 +667,7 @@ def __init__(
658667
messenger_args=messenger_args,
659668
cache_dir=cache_dir,
660669
strip=strip,
670+
rerun=rerun,
661671
**kwargs,
662672
)
663673
self.init = True

pydra/engine/tests/test_workflow.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,126 @@ def test_wf_nostate_cachelocations_forcererun(plugin, tmpdir):
19881988
assert wf2.output_dir.exists()
19891989

19901990

1991+
@pytest.mark.parametrize("plugin", Plugins)
1992+
def test_wf_nostate_cachelocations_wftaskrerun(plugin, tmpdir):
1993+
"""
1994+
Two identical wfs with provided cache_dir;
1995+
the second wf has cache_locations and rerun=True,
1996+
so the workflow is rerun, but it doesn't propagate to nodes,
1997+
so none of the nodes has to be recalculated
1998+
"""
1999+
cache_dir1 = tmpdir.mkdir("test_wf_cache3")
2000+
cache_dir2 = tmpdir.mkdir("test_wf_cache4")
2001+
2002+
wf1 = Workflow(name="wf", input_spec=["x", "y"], cache_dir=cache_dir1)
2003+
wf1.add(multiply(name="mult", x=wf1.lzin.x, y=wf1.lzin.y))
2004+
wf1.add(add2_wait(name="add2", x=wf1.mult.lzout.out))
2005+
wf1.set_output([("out", wf1.add2.lzout.out)])
2006+
wf1.inputs.x = 2
2007+
wf1.inputs.y = 3
2008+
wf1.plugin = plugin
2009+
2010+
t0 = time.time()
2011+
with Submitter(plugin=plugin) as sub:
2012+
sub(wf1)
2013+
t1 = time.time() - t0
2014+
2015+
results1 = wf1.result()
2016+
assert 8 == results1.output.out
2017+
2018+
wf2 = Workflow(
2019+
name="wf",
2020+
input_spec=["x", "y"],
2021+
cache_dir=cache_dir2,
2022+
cache_locations=cache_dir1,
2023+
rerun=True,
2024+
)
2025+
wf2.add(multiply(name="mult", x=wf2.lzin.x, y=wf2.lzin.y))
2026+
wf2.add(add2_wait(name="add2", x=wf2.mult.lzout.out))
2027+
wf2.set_output([("out", wf2.add2.lzout.out)])
2028+
wf2.inputs.x = 2
2029+
wf2.inputs.y = 3
2030+
wf2.plugin = plugin
2031+
2032+
t0 = time.time()
2033+
with Submitter(plugin=plugin) as sub:
2034+
sub(wf2)
2035+
t2 = time.time() - t0
2036+
2037+
results2 = wf2.result()
2038+
assert 8 == results2.output.out
2039+
2040+
# checking if the second wf runs again
2041+
assert wf1.output_dir.exists()
2042+
assert wf2.output_dir.exists()
2043+
2044+
# even if the second wf is recomputed the nodes are not, so it's fast
2045+
assert len(list(Path(cache_dir1).glob("F*"))) == 2
2046+
assert len(list(Path(cache_dir2).glob("F*"))) == 0
2047+
assert t1 > 3
2048+
assert t2 < 1
2049+
2050+
2051+
@pytest.mark.parametrize("plugin", Plugins)
2052+
def test_wf_nostate_cachelocations_taskrerun(plugin, tmpdir):
2053+
"""
2054+
Two identical wfs with provided cache_dir;
2055+
the second wf has cache_locations and rerun=True
2056+
and one of the tasks also has rerun=True;
2057+
so the workflow is rerun, and one of the nodes also has to be rerun
2058+
"""
2059+
cache_dir1 = tmpdir.mkdir("test_wf_cache3")
2060+
cache_dir2 = tmpdir.mkdir("test_wf_cache4")
2061+
2062+
wf1 = Workflow(name="wf", input_spec=["x", "y"], cache_dir=cache_dir1)
2063+
wf1.add(multiply(name="mult", x=wf1.lzin.x, y=wf1.lzin.y))
2064+
wf1.add(add2_wait(name="add2", x=wf1.mult.lzout.out))
2065+
wf1.set_output([("out", wf1.add2.lzout.out)])
2066+
wf1.inputs.x = 2
2067+
wf1.inputs.y = 3
2068+
wf1.plugin = plugin
2069+
2070+
t0 = time.time()
2071+
with Submitter(plugin=plugin) as sub:
2072+
sub(wf1)
2073+
t1 = time.time() - t0
2074+
2075+
results1 = wf1.result()
2076+
assert 8 == results1.output.out
2077+
2078+
wf2 = Workflow(
2079+
name="wf",
2080+
input_spec=["x", "y"],
2081+
cache_dir=cache_dir2,
2082+
cache_locations=cache_dir1,
2083+
rerun=True,
2084+
)
2085+
wf2.add(multiply(name="mult", x=wf2.lzin.x, y=wf2.lzin.y))
2086+
wf2.add(add2_wait(name="add2", x=wf2.mult.lzout.out, rerun=True))
2087+
wf2.set_output([("out", wf2.add2.lzout.out)])
2088+
wf2.inputs.x = 2
2089+
wf2.inputs.y = 3
2090+
wf2.plugin = plugin
2091+
2092+
t0 = time.time()
2093+
with Submitter(plugin=plugin) as sub:
2094+
sub(wf2)
2095+
t2 = time.time() - t0
2096+
2097+
results2 = wf2.result()
2098+
assert 8 == results2.output.out
2099+
2100+
# checking if the second wf runs again
2101+
assert wf1.output_dir.exists()
2102+
assert wf2.output_dir.exists()
2103+
2104+
# the second task also has to be recomputed this time
2105+
assert len(list(Path(cache_dir1).glob("F*"))) == 2
2106+
assert len(list(Path(cache_dir2).glob("F*"))) == 1
2107+
assert t1 > 3
2108+
assert t2 > 3
2109+
2110+
19912111
@pytest.mark.parametrize("plugin", Plugins)
19922112
def test_wf_nostate_nodecachelocations(plugin, tmpdir):
19932113
"""
@@ -2408,7 +2528,7 @@ def test_wf_nostate_cachelocations_recompute(plugin, tmpdir):
24082528

24092529
# checking execution time (second task shouldn't be recomputed, t2 should be small)
24102530
assert t1 > 3
2411-
assert t2 < 0.5
2531+
assert t2 < 1
24122532

24132533

24142534
@pytest.mark.parametrize("plugin", Plugins)

0 commit comments

Comments
 (0)