@@ -84,6 +84,7 @@ def __init__(
84
84
inputs : ty .Union [ty .Text , File , ty .Dict , None ] = None ,
85
85
messenger_args = None ,
86
86
messengers = None ,
87
+ rerun = False ,
87
88
):
88
89
"""
89
90
Initialize a task.
@@ -173,6 +174,8 @@ def __init__(
173
174
self .cache_locations = cache_locations
174
175
self .allow_cache_override = True
175
176
self ._checksum = None
177
+ # if True the results are not checked (does not propagate to nodes)
178
+ self .task_rerun = rerun
176
179
177
180
self .plugin = None
178
181
self .hooks = TaskHook ()
@@ -365,7 +368,7 @@ def _run(self, rerun=False, **kwargs):
365
368
self .hooks .pre_run (self )
366
369
# TODO add signal handler for processes killed after lock acquisition
367
370
with SoftFileLock (lockfile ):
368
- if not rerun :
371
+ if not ( rerun or self . task_rerun ) :
369
372
result = self .result ()
370
373
if result is not None :
371
374
return result
@@ -610,6 +613,8 @@ def __init__(
610
613
messenger_args = None ,
611
614
messengers = None ,
612
615
output_spec : ty .Optional [BaseSpec ] = None ,
616
+ rerun = False ,
617
+ propagate_rerun = True ,
613
618
** kwargs ,
614
619
):
615
620
"""
@@ -671,13 +676,16 @@ def __init__(
671
676
audit_flags = audit_flags ,
672
677
messengers = messengers ,
673
678
messenger_args = messenger_args ,
679
+ rerun = rerun ,
674
680
)
675
681
676
682
self .graph = DiGraph ()
677
683
self .name2obj = {}
678
684
679
685
# store output connections
680
686
self ._connections = None
687
+ # propagating rerun if task_rerun=True
688
+ self .propagate_rerun = propagate_rerun
681
689
682
690
def __getattr__ (self , name ):
683
691
if name == "lzin" :
@@ -789,12 +797,20 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
789
797
checksum = self .checksum
790
798
lockfile = self .cache_dir / (checksum + ".lock" )
791
799
# Eagerly retrieve cached
792
- if not rerun :
800
+ if not ( rerun or self . task_rerun ) :
793
801
result = self .result ()
794
802
if result is not None :
795
803
return result
796
804
# creating connections that were defined after adding tasks to the wf
797
805
for task in self .graph .nodes :
806
+ # if workflow has task_rerun=True and propagate_rerun=True,
807
+ # it should be passed to the tasks
808
+ if self .task_rerun and self .propagate_rerun :
809
+ task .task_rerun = self .task_rerun
810
+ # if the task is a wf, than the propagate_rerun should be also set
811
+ if is_workflow (task ):
812
+ task .propagate_rerun = self .propagate_rerun
813
+ task .cache_locations = task ._cache_locations + self .cache_locations
798
814
self .create_connections (task )
799
815
# TODO add signal handler for processes killed after lock acquisition
800
816
self .hooks .pre_run (self )
@@ -810,7 +826,7 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
810
826
self .hooks .pre_run_task (self )
811
827
try :
812
828
self .audit .monitor ()
813
- await self ._run_task (submitter )
829
+ await self ._run_task (submitter , rerun = rerun )
814
830
result .output = self ._collect_outputs ()
815
831
except Exception as e :
816
832
record_error (self .output_dir , e )
@@ -824,11 +840,11 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
824
840
self .hooks .post_run (self , result )
825
841
return result
826
842
827
- async def _run_task (self , submitter ):
843
+ async def _run_task (self , submitter , rerun = False ):
828
844
if not submitter :
829
845
raise Exception ("Submitter should already be set." )
830
846
# at this point Workflow is stateless so this should be fine
831
- await submitter ._run_workflow (self )
847
+ await submitter ._run_workflow (self , rerun = rerun )
832
848
833
849
def set_output (self , connections ):
834
850
"""
0 commit comments