Skip to content

Commit b1a8f4c

Browse files
committed
TST: Decrease divergence between Task/Workflow _run_task
1 parent 74957ca commit b1a8f4c

File tree

1 file changed

+63
-59
lines changed

1 file changed

+63
-59
lines changed

pydra/engine/core.py

Lines changed: 63 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -452,42 +452,53 @@ def __call__(
452452
res = self._run(rerun=rerun, **kwargs)
453453
return res
454454

455+
def prepare_run_task(self, rerun):
456+
"""
457+
Invoked immediately after the lockfile is generated, this function:
458+
- Returns any cached, non-errored result when a rerun was not requested.
- Writes a ``{uid}_info.json`` file recording the checksum, so the lockfile
  can be cleaned up if the task is cancelled.
- Prepares the output directory (removing an existing one unless the task
  can resume) and, for non-workflow tasks, applies input file copying and
  output-template updates.
459+
- Creates an empty Result and passes it along to be populated.
460+
461+
Created as an attempt to simplify overlapping `Task`|`Workflow` behaviors.
462+
"""
463+
# retrieve cached results
464+
if not (rerun or self.task_rerun):
465+
result = self.result()
466+
if result is not None and not result.errored:
467+
return result
468+
# adding info file with the checksum in case the task was cancelled
469+
# and the lockfile has to be removed
470+
with open(self.cache_dir / f"{self.uid}_info.json", "w") as jsonfile:
471+
json.dump({"checksum": self.checksum}, jsonfile)
472+
if not self.can_resume and self.output_dir.exists():
473+
shutil.rmtree(self.output_dir)
474+
self.output_dir.mkdir(parents=False, exist_ok=self.can_resume)
475+
if not is_workflow(self):
476+
self._orig_inputs = attr.asdict(self.inputs, recurse=False)
477+
map_copyfiles = copyfile_input(self.inputs, self.output_dir)
478+
modified_inputs = template_update(
479+
self.inputs, self.output_dir, map_copyfiles=map_copyfiles
480+
)
481+
if modified_inputs:
482+
self.inputs = attr.evolve(self.inputs, **modified_inputs)
483+
self.audit.start_audit(odir=self.output_dir)
484+
result = Result(output=None, runtime=None, errored=False)
485+
self.hooks.pre_run_task(self)
486+
return result
487+
455488
def _run(self, rerun=False, **kwargs):
456489
self.inputs = attr.evolve(self.inputs, **kwargs)
457490
self.inputs.check_fields_input_spec()
458-
checksum = self.checksum
459-
lockfile = self.cache_dir / (checksum + ".lock")
491+
492+
lockfile = self.cache_dir / (self.checksum + ".lock")
460493
# Eagerly retrieve cached - see scenarios in __init__()
461494
self.hooks.pre_run(self)
462495
with SoftFileLock(lockfile):
463-
if not (rerun or self.task_rerun):
464-
result = self.result()
465-
if result is not None and not result.errored:
466-
return result
467-
# adding info file with the checksum in case the task was cancelled
468-
# and the lockfile has to be removed
469-
with open(self.cache_dir / f"{self.uid}_info.json", "w") as jsonfile:
470-
json.dump({"checksum": self.checksum}, jsonfile)
471-
# Let only one equivalent process run
472-
odir = self.output_dir
473-
if not self.can_resume and odir.exists():
474-
shutil.rmtree(odir)
475496
cwd = os.getcwd()
476-
odir.mkdir(parents=False, exist_ok=True if self.can_resume else False)
477-
orig_inputs = attr.asdict(self.inputs, recurse=False)
478-
map_copyfiles = copyfile_input(self.inputs, self.output_dir)
479-
modified_inputs = template_update(
480-
self.inputs, self.output_dir, map_copyfiles=map_copyfiles
481-
)
482-
if modified_inputs:
483-
self.inputs = attr.evolve(self.inputs, **modified_inputs)
484-
self.audit.start_audit(odir)
485-
result = Result(output=None, runtime=None, errored=False)
486-
self.hooks.pre_run_task(self)
497+
result = self.prepare_run_task(rerun)
487498
try:
488499
self.audit.monitor()
489500
self._run_task()
490-
result.output = self._collect_outputs(output_dir=odir)
501+
result.output = self._collect_outputs(output_dir=self.output_dir)
491502
except Exception:
492503
etype, eval, etr = sys.exc_info()
493504
traceback = format_exception(etype, eval, etr)
@@ -497,15 +508,17 @@ def _run(self, rerun=False, **kwargs):
497508
finally:
498509
self.hooks.post_run_task(self, result)
499510
self.audit.finalize_audit(result)
500-
save(odir, result=result, task=self)
511+
save(self.output_dir, result=result, task=self)
501512
self.output_ = None
502513
# removing the additional file with the checksum
503514
(self.cache_dir / f"{self.uid}_info.json").unlink()
504515
# # function etc. shouldn't change anyway, so removing
505-
orig_inputs = dict(
506-
(k, v) for (k, v) in orig_inputs.items() if not k.startswith("_")
516+
self._orig_inputs = dict(
517+
(k, v) for (k, v) in self._orig_inputs.items() if not k.startswith("_")
507518
)
508-
self.inputs = attr.evolve(self.inputs, **orig_inputs)
519+
self.inputs = attr.evolve(self.inputs, **self._orig_inputs)
520+
# no need to propagate this
521+
del self._orig_inputs
509522
os.chdir(cwd)
510523
self.hooks.post_run(self, result)
511524
return result
@@ -1038,38 +1051,13 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
10381051
raise ValueError(
10391052
"Workflow output cannot be None, use set_output to define output(s)"
10401053
)
1041-
checksum = self.checksum
10421054
# creating connections that were defined after adding tasks to the wf
1043-
for task in self.graph.nodes:
1044-
# if workflow has task_rerun=True and propagate_rerun=True,
1045-
# it should be passed to the tasks
1046-
if self.task_rerun and self.propagate_rerun:
1047-
task.task_rerun = self.task_rerun
1048-
# if the task is a wf, then the propagate_rerun should also be set
1049-
if is_workflow(task):
1050-
task.propagate_rerun = self.propagate_rerun
1051-
task.cache_locations = task._cache_locations + self.cache_locations
1052-
self.create_connections(task)
1053-
lockfile = self.cache_dir / (checksum + ".lock")
1055+
self.connect_and_propagate_to_tasks()
1056+
lockfile = self.cache_dir / (self.checksum + ".lock")
10541057
self.hooks.pre_run(self)
10551058
async with PydraFileLock(lockfile):
1056-
# retrieve cached results
1057-
if not (rerun or self.task_rerun):
1058-
result = self.result()
1059-
if result is not None and not result.errored:
1060-
return result
1061-
# adding info file with the checksum in case the task was cancelled
1062-
# and the lockfile has to be removed
1063-
with open(self.cache_dir / f"{self.uid}_info.json", "w") as jsonfile:
1064-
json.dump({"checksum": checksum}, jsonfile)
1065-
odir = self.output_dir
1066-
if not self.can_resume and odir.exists():
1067-
shutil.rmtree(odir)
10681059
cwd = os.getcwd()
1069-
odir.mkdir(parents=False, exist_ok=True if self.can_resume else False)
1070-
self.audit.start_audit(odir=odir)
1071-
result = Result(output=None, runtime=None, errored=False)
1072-
self.hooks.pre_run_task(self)
1060+
result = self.prepare_run_task(rerun)
10731061
try:
10741062
self.audit.monitor()
10751063
await self._run_task(submitter, rerun=rerun)
@@ -1084,7 +1072,7 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
10841072
finally:
10851073
self.hooks.post_run_task(self, result)
10861074
self.audit.finalize_audit(result=result)
1087-
save(odir, result=result, task=self)
1075+
save(self.output_dir, result=result, task=self)
10881076
# removing the additional file with the checksum
10891077
(self.cache_dir / f"{self.uid}_info.json").unlink()
10901078
os.chdir(cwd)
@@ -1226,6 +1214,22 @@ def create_dotfile(self, type="simple", export=None, name=None):
12261214
formatted_dot.append(self.graph.export_graph(dotfile=dotfile, ext=ext))
12271215
return dotfile, formatted_dot
12281216

1217+
def connect_and_propagate_to_tasks(self):
1218+
"""
1219+
Visit each node in the graph and create the connections.
1220+
Also propagates the rerun settings and cache locations to each task.
1221+
"""
1222+
for task in self.graph.nodes:
1223+
# if workflow has task_rerun=True and propagate_rerun=True,
1224+
# it should be passed to the tasks
1225+
if self.task_rerun and self.propagate_rerun:
1226+
task.task_rerun = self.task_rerun
1227+
# if the task is a wf, then the propagate_rerun should also be set
1228+
if is_workflow(task):
1229+
task.propagate_rerun = self.propagate_rerun
1230+
task.cache_locations = task._cache_locations + self.cache_locations
1231+
self.create_connections(task)
1232+
12291233

12301234
def is_task(obj):
12311235
"""Check whether an object looks like a task."""

0 commit comments

Comments
 (0)