Skip to content

Commit 73ac982

Browse files
authored
Merge pull request #208 from djarecka/fix/usecache_nodelevel
[fix] passing wf.cache_locations and submitter.rerun
2 parents c57d830 + 5f71205 commit 73ac982

File tree

4 files changed

+325
-10
lines changed

4 files changed

+325
-10
lines changed

pydra/engine/core.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
inputs: ty.Union[ty.Text, File, ty.Dict, None] = None,
8585
messenger_args=None,
8686
messengers=None,
87+
rerun=False,
8788
):
8889
"""
8990
Initialize a task.
@@ -173,6 +174,8 @@ def __init__(
173174
self.cache_locations = cache_locations
174175
self.allow_cache_override = True
175176
self._checksum = None
177+
# if True the results are not checked (does not propagate to nodes)
178+
self.task_rerun = rerun
176179

177180
self.plugin = None
178181
self.hooks = TaskHook()
@@ -365,7 +368,7 @@ def _run(self, rerun=False, **kwargs):
365368
self.hooks.pre_run(self)
366369
# TODO add signal handler for processes killed after lock acquisition
367370
with SoftFileLock(lockfile):
368-
if not rerun:
371+
if not (rerun or self.task_rerun):
369372
result = self.result()
370373
if result is not None:
371374
return result
@@ -610,6 +613,8 @@ def __init__(
610613
messenger_args=None,
611614
messengers=None,
612615
output_spec: ty.Optional[BaseSpec] = None,
616+
rerun=False,
617+
propagate_rerun=True,
613618
**kwargs,
614619
):
615620
"""
@@ -671,13 +676,16 @@ def __init__(
671676
audit_flags=audit_flags,
672677
messengers=messengers,
673678
messenger_args=messenger_args,
679+
rerun=rerun,
674680
)
675681

676682
self.graph = DiGraph()
677683
self.name2obj = {}
678684

679685
# store output connections
680686
self._connections = None
687+
# propagating rerun if task_rerun=True
688+
self.propagate_rerun = propagate_rerun
681689

682690
def __getattr__(self, name):
683691
if name == "lzin":
@@ -789,12 +797,20 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
789797
checksum = self.checksum
790798
lockfile = self.cache_dir / (checksum + ".lock")
791799
# Eagerly retrieve cached
792-
if not rerun:
800+
if not (rerun or self.task_rerun):
793801
result = self.result()
794802
if result is not None:
795803
return result
796804
# creating connections that were defined after adding tasks to the wf
797805
for task in self.graph.nodes:
806+
# if workflow has task_rerun=True and propagate_rerun=True,
807+
# it should be passed to the tasks
808+
if self.task_rerun and self.propagate_rerun:
809+
task.task_rerun = self.task_rerun
810+
# if the task is a wf, than the propagate_rerun should be also set
811+
if is_workflow(task):
812+
task.propagate_rerun = self.propagate_rerun
813+
task.cache_locations = task._cache_locations + self.cache_locations
798814
self.create_connections(task)
799815
# TODO add signal handler for processes killed after lock acquisition
800816
self.hooks.pre_run(self)
@@ -810,7 +826,7 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
810826
self.hooks.pre_run_task(self)
811827
try:
812828
self.audit.monitor()
813-
await self._run_task(submitter)
829+
await self._run_task(submitter, rerun=rerun)
814830
result.output = self._collect_outputs()
815831
except Exception as e:
816832
record_error(self.output_dir, e)
@@ -824,11 +840,11 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
824840
self.hooks.post_run(self, result)
825841
return result
826842

827-
async def _run_task(self, submitter):
843+
async def _run_task(self, submitter, rerun=False):
828844
if not submitter:
829845
raise Exception("Submitter should already be set.")
830846
# at this point Workflow is stateless so this should be fine
831-
await submitter._run_workflow(self)
847+
await submitter._run_workflow(self, rerun=rerun)
832848

833849
def set_output(self, connections):
834850
"""

pydra/engine/task.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
messengers=None,
7676
name=None,
7777
output_spec: ty.Optional[BaseSpec] = None,
78+
rerun=False,
7879
**kwargs,
7980
):
8081
"""
@@ -143,6 +144,7 @@ def __init__(
143144
messenger_args=messenger_args,
144145
cache_dir=cache_dir,
145146
cache_locations=cache_locations,
147+
rerun=rerun,
146148
)
147149
if output_spec is None:
148150
if "return" not in func.__annotations__:
@@ -242,6 +244,7 @@ def __init__(
242244
messengers=None,
243245
name=None,
244246
output_spec: ty.Optional[SpecInfo] = None,
247+
rerun=False,
245248
strip=False,
246249
**kwargs,
247250
):
@@ -283,6 +286,7 @@ def __init__(
283286
messengers=messengers,
284287
messenger_args=messenger_args,
285288
cache_dir=cache_dir,
289+
rerun=rerun,
286290
)
287291
self.strip = strip
288292

@@ -412,6 +416,7 @@ def __init__(
412416
messengers=None,
413417
output_cpath="/output_pydra",
414418
output_spec: ty.Optional[SpecInfo] = None,
419+
rerun=False,
415420
strip=False,
416421
**kwargs,
417422
):
@@ -452,6 +457,7 @@ def __init__(
452457
messenger_args=messenger_args,
453458
cache_dir=cache_dir,
454459
strip=strip,
460+
rerun=rerun,
455461
**kwargs,
456462
)
457463

@@ -523,6 +529,7 @@ def __init__(
523529
messengers=None,
524530
output_cpath="/output_pydra",
525531
output_spec: ty.Optional[SpecInfo] = None,
532+
rerun=False,
526533
strip=False,
527534
**kwargs,
528535
):
@@ -565,6 +572,7 @@ def __init__(
565572
cache_dir=cache_dir,
566573
strip=strip,
567574
output_cpath=output_cpath,
575+
rerun=rerun,
568576
**kwargs,
569577
)
570578
self.inputs.container_xargs = ["--rm"]
@@ -619,6 +627,7 @@ def __init__(
619627
messenger_args=None,
620628
messengers=None,
621629
output_spec: ty.Optional[SpecInfo] = None,
630+
rerun=False,
622631
strip=False,
623632
**kwargs,
624633
):
@@ -658,6 +667,7 @@ def __init__(
658667
messenger_args=messenger_args,
659668
cache_dir=cache_dir,
660669
strip=strip,
670+
rerun=rerun,
661671
**kwargs,
662672
)
663673
self.init = True

0 commit comments

Comments
 (0)