Skip to content

Commit e5c2556

Browse files
committed
implemented non-asynchronous "debug" worker that avoids async/await for easier debugging of workflows
1 parent abcf93f commit e5c2556

File tree

10 files changed

+384
-289
lines changed

10 files changed

+384
-289
lines changed

new-docs/source/tutorial/tst.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121

2222
# Instantiate the task definition, "splitting" over all NIfTI files in the test directory
2323
# by splitting the "input" input field over all files in the directory
24-
mrgrid = MrGrid(voxel=(0.5, 0.5, 0.5)).split(in_file=nifti_dir.iterdir())
24+
mrgrid = MrGrid(operation="regrid", voxel=(0.5, 0.5, 0.5)).split(
25+
in_file=nifti_dir.iterdir()
26+
)
2527

2628
# Run the task to resample all NIfTI files
27-
outputs = mrgrid(worker="serial")
29+
outputs = mrgrid(worker="cf")
2830

2931
# Print the locations of the output files
30-
print("\n".join(str(p) for p in outputs.outputs))
32+
print("\n".join(str(p) for p in outputs.out_file))

pydra/design/base.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,7 @@ def make_task_def(
388388
klass : type
389389
The class created using the attrs package
390390
"""
391-
from pydra.engine.specs import TaskDef, WorkflowDef
392-
from pydra.engine.core import Task, WorkflowTask
391+
from pydra.engine.specs import TaskDef
393392

394393
spec_type._check_arg_refs(inputs, outputs)
395394

@@ -400,7 +399,6 @@ def make_task_def(
400399
f"{reserved_names} are reserved and cannot be used for {spec_type} field names"
401400
)
402401
outputs_klass = make_outputs_spec(out_type, outputs, outputs_bases, name)
403-
task_type = WorkflowTask if issubclass(spec_type, WorkflowDef) else Task
404402
if klass is None or not issubclass(klass, spec_type):
405403
if name is None:
406404
raise ValueError("name must be provided if klass is not")
@@ -419,19 +417,13 @@ def make_task_def(
419417
name=name,
420418
bases=bases,
421419
kwds={},
422-
exec_body=lambda ns: ns.update(
423-
{
424-
"Outputs": outputs_klass,
425-
"Task": task_type,
426-
}
427-
),
420+
exec_body=lambda ns: ns.update({"Outputs": outputs_klass}),
428421
)
429422
else:
430423
# Ensure that the class has its own annotations dict so we can modify it without
431424
# messing up other classes
432425
klass.__annotations__ = copy(klass.__annotations__)
433426
klass.Outputs = outputs_klass
434-
klass.Task = task_type
435427
# Now that we have saved the attributes in lists to be
436428
for arg in inputs.values():
437429
# If an outarg input then the field type should be Path not a FileSet

pydra/engine/core.py

Lines changed: 86 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from .helpers_file import copy_nested_files, template_update
4545
from pydra.utils.messenger import AuditFlag
46+
from pydra.engine.environments import Environment, Native
4647

4748
logger = logging.getLogger("pydra")
4849

@@ -85,7 +86,8 @@ class Task(ty.Generic[DefType]):
8586

8687
name: str
8788
definition: DefType
88-
submitter: "Submitter"
89+
submitter: "Submitter | None"
90+
environment: "Environment | None"
8991
state_index: state.StateIndex
9092

9193
_inputs: dict[str, ty.Any] | None = None
@@ -95,6 +97,7 @@ def __init__(
9597
definition: DefType,
9698
submitter: "Submitter",
9799
name: str,
100+
environment: "Environment | None" = None,
98101
state_index: "state.StateIndex | None" = None,
99102
):
100103
"""
@@ -121,14 +124,16 @@ def __init__(
121124
if state_index is None:
122125
state_index = state.StateIndex()
123126

124-
self.definition = definition
127+
# Copy the definition, so lazy fields can be resolved and replaced at runtime
128+
self.definition = copy(definition)
129+
# We save the submitter if the definition is a workflow, otherwise we don't,
130+
# so the task can be pickled
131+
self.submitter = submitter if is_workflow(definition) else None
132+
self.environment = environment if environment is not None else Native()
125133
self.name = name
126134
self.state_index = state_index
127135

128-
# checking if metadata is set properly
129-
self.definition._check_resolved()
130-
self.definition._check_rules()
131-
self._output = {}
136+
self.return_values = {}
132137
self._result = {}
133138
# flag that says if node finished all jobs
134139
self._done = False
@@ -151,6 +156,11 @@ def __init__(
151156
def cache_dir(self):
152157
return self._cache_dir
153158

159+
@property
160+
def is_async(self) -> bool:
161+
"""Check to see if the task should be run asynchronously."""
162+
return self.submitter.worker.is_async and is_workflow(self.definition)
163+
154164
@cache_dir.setter
155165
def cache_dir(self, path: os.PathLike):
156166
self._cache_dir = Path(path)
@@ -315,6 +325,21 @@ def _populate_filesystem(self):
315325
self.output_dir.mkdir(parents=False, exist_ok=self.can_resume)
316326

317327
def run(self, rerun: bool = False):
328+
"""Prepare the task working directory, execute the task definition, and save the
329+
results.
330+
331+
Parameters
332+
----------
333+
rerun : bool
334+
If True, the task will be re-run even if a result already exists. Will
335+
be propagated to all tasks within workflow tasks.
336+
"""
337+
# TODO: After these changes have been merged, will refactor this function and
338+
# run_async to use common helper methods for pre/post run tasks
339+
340+
# checking if the definition is fully resolved and ready to run
341+
self.definition._check_resolved()
342+
self.definition._check_rules()
318343
self.hooks.pre_run(self)
319344
logger.debug(
320345
"'%s' is attempting to acquire lock on %s", self.name, self.lockfile
@@ -334,8 +359,8 @@ def run(self, rerun: bool = False):
334359
self.audit.audit_task(task=self)
335360
try:
336361
self.audit.monitor()
337-
run_outputs = self.definition._run(self)
338-
result.outputs = self.definition.Outputs.from_task(self, run_outputs)
362+
self.definition._run(self)
363+
result.outputs = self.definition.Outputs._from_task(self)
339364
except Exception:
340365
etype, eval, etr = sys.exc_info()
341366
traceback = format_exception(etype, eval, etr)
@@ -355,6 +380,57 @@ def run(self, rerun: bool = False):
355380
self._check_for_hash_changes()
356381
return result
357382

383+
async def run_async(self, rerun: bool = False):
384+
"""Prepare the task working directory, execute the task definition asynchronously,
385+
and save the results. NB: only workflows are run asynchronously at the moment.
386+
387+
Parameters
388+
----------
389+
rerun : bool
390+
If True, the task will be re-run even if a result already exists. Will
391+
be propagated to all tasks within workflow tasks.
392+
"""
393+
# checking if the definition is fully resolved and ready to run
394+
self.definition._check_resolved()
395+
self.definition._check_rules()
396+
self.hooks.pre_run(self)
397+
logger.debug(
398+
"'%s' is attempting to acquire lock on %s", self.name, self.lockfile
399+
)
400+
async with PydraFileLock(self.lockfile):
401+
if not rerun:
402+
result = self.result()
403+
if result is not None and not result.errored:
404+
return result
405+
cwd = os.getcwd()
406+
self._populate_filesystem()
407+
result = Result(outputs=None, runtime=None, errored=False, task=self)
408+
self.hooks.pre_run_task(self)
409+
self.audit.start_audit(odir=self.output_dir)
410+
try:
411+
self.audit.monitor()
412+
await self.definition._run(self)
413+
result.outputs = self.definition.Outputs._from_task(self)
414+
except Exception:
415+
etype, eval, etr = sys.exc_info()
416+
traceback = format_exception(etype, eval, etr)
417+
record_error(self.output_dir, error=traceback)
418+
result.errored = True
419+
self._errored = True
420+
raise
421+
finally:
422+
self.hooks.post_run_task(self, result)
423+
self.audit.finalize_audit(result=result)
424+
save(self.output_dir, result=result, task=self)
425+
# removing the additional file with the checksum
426+
(self.cache_dir / f"{self.uid}_info.json").unlink()
427+
os.chdir(cwd)
428+
self.hooks.post_run(self, result)
429+
# Check for any changes to the input hashes that have occurred during the execution
430+
# of the task
431+
self._check_for_hash_changes()
432+
return result
433+
358434
def pickle_task(self):
359435
"""Pickling the tasks with full inputs"""
360436
pkl_files = self.cache_dir / "pkl_files"
@@ -398,7 +474,7 @@ def _combined_output(self, return_inputs=False):
398474
else:
399475
return combined_results
400476

401-
def result(self, state_index=None, return_inputs=False):
477+
def result(self, return_inputs=False):
402478
"""
403479
Retrieve the outcomes of this particular task.
404480
@@ -415,13 +491,9 @@ def result(self, state_index=None, return_inputs=False):
415491
result : Result
416492
the result of the task
417493
"""
418-
# TODO: check if result is available in load_result and
419-
# return a future if not
420494
if self.errored:
421495
return Result(outputs=None, runtime=None, errored=True, task=self)
422496

423-
if state_index is not None:
424-
raise ValueError("Task does not have a state")
425497
checksum = self.checksum
426498
result = load_result(checksum, self.cache_locations)
427499
if result and result.errored:
@@ -483,58 +555,6 @@ def _check_for_hash_changes(self):
483555
DEFAULT_COPY_COLLATION = FileSet.CopyCollation.any
484556

485557

486-
class WorkflowTask(Task):
487-
488-
def __init__(
489-
self,
490-
definition: DefType,
491-
submitter: "Submitter",
492-
name: str,
493-
state_index: "state.StateIndex | None" = None,
494-
):
495-
super().__init__(definition, submitter, name, state_index)
496-
self.submitter = submitter
497-
498-
async def run(self, rerun: bool = False):
499-
self.hooks.pre_run(self)
500-
logger.debug(
501-
"'%s' is attempting to acquire lock on %s", self.name, self.lockfile
502-
)
503-
async with PydraFileLock(self.lockfile):
504-
if not rerun:
505-
result = self.result()
506-
if result is not None and not result.errored:
507-
return result
508-
cwd = os.getcwd()
509-
self._populate_filesystem()
510-
result = Result(outputs=None, runtime=None, errored=False, task=self)
511-
self.hooks.pre_run_task(self)
512-
self.audit.start_audit(odir=self.output_dir)
513-
try:
514-
self.audit.monitor()
515-
await self.submitter.expand_workflow(self)
516-
result.outputs = self.definition.Outputs.from_task(self)
517-
except Exception:
518-
etype, eval, etr = sys.exc_info()
519-
traceback = format_exception(etype, eval, etr)
520-
record_error(self.output_dir, error=traceback)
521-
result.errored = True
522-
self._errored = True
523-
raise
524-
finally:
525-
self.hooks.post_run_task(self, result)
526-
self.audit.finalize_audit(result=result)
527-
save(self.output_dir, result=result, task=self)
528-
# removing the additional file with the checksum
529-
(self.cache_dir / f"{self.uid}_info.json").unlink()
530-
os.chdir(cwd)
531-
self.hooks.post_run(self, result)
532-
# Check for any changes to the input hashes that have occurred during the execution
533-
# of the task
534-
self._check_for_hash_changes()
535-
return result
536-
537-
538558
logger = logging.getLogger("pydra")
539559

540560
OutputsType = ty.TypeVar("OutputType", bound=TaskOutputs)
@@ -847,18 +867,12 @@ def create_dotfile(self, type="simple", export=None, name=None, output_dir=None)
847867
return dotfile, formatted_dot
848868

849869

850-
def is_task(obj):
851-
"""Check whether an object looks like a task."""
852-
return hasattr(obj, "_run_task")
853-
854-
855870
def is_workflow(obj):
856871
"""Check whether an object is a :class:`Workflow` instance."""
857872
from pydra.engine.specs import WorkflowDef
858873
from pydra.engine.core import Workflow
859-
from pydra.engine.core import WorkflowTask
860874

861-
return isinstance(obj, (WorkflowDef, WorkflowTask, Workflow))
875+
return isinstance(obj, (WorkflowDef, Workflow))
862876

863877

864878
def has_lazy(obj):

pydra/engine/environments.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44

55
if ty.TYPE_CHECKING:
6-
from pydra.engine.task import ShellTask
6+
from pydra.engine.core import Task
77

88

99
class Environment:
@@ -17,7 +17,7 @@ class Environment:
1717
def setup(self):
1818
pass
1919

20-
def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
20+
def execute(self, task: "Task") -> dict[str, ty.Any]:
2121
"""
2222
Execute the task in the environment.
2323
@@ -42,12 +42,12 @@ class Native(Environment):
4242
Native environment, i.e. the tasks are executed in the current python environment.
4343
"""
4444

45-
def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
45+
def execute(self, task: "Task") -> dict[str, ty.Any]:
4646
keys = ["return_code", "stdout", "stderr"]
47-
values = execute(task.definition._command_args(), strip=task.strip)
47+
values = execute(task.definition._command_args())
4848
output = dict(zip(keys, values))
4949
if output["return_code"]:
50-
msg = f"Error running '{task.name}' task with {task.command_args()}:"
50+
msg = f"Error running '{task.name}' task with {task.definition._command_args()}:"
5151
if output["stderr"]:
5252
msg += "\n\nstderr:\n" + output["stderr"]
5353
if output["stdout"]:
@@ -90,7 +90,7 @@ def bind(self, loc, mode="ro"):
9090
class Docker(Container):
9191
"""Docker environment."""
9292

93-
def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
93+
def execute(self, task: "Task") -> dict[str, ty.Any]:
9494
docker_img = f"{self.image}:{self.tag}"
9595
# mounting all input locations
9696
mounts = task.definition._get_bindings(root=self.root)
@@ -112,7 +112,6 @@ def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
112112

113113
values = execute(
114114
docker_args + [docker_img] + task.definition._command_args(root=self.root),
115-
strip=task.strip,
116115
)
117116
output = dict(zip(keys, values))
118117
if output["return_code"]:
@@ -126,7 +125,7 @@ def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
126125
class Singularity(Container):
127126
"""Singularity environment."""
128127

129-
def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
128+
def execute(self, task: "Task") -> dict[str, ty.Any]:
130129
singularity_img = f"{self.image}:{self.tag}"
131130
# mounting all input locations
132131
mounts = task.definition._get_bindings(root=self.root)
@@ -151,7 +150,6 @@ def execute(self, task: "ShellTask") -> dict[str, ty.Any]:
151150
singularity_args
152151
+ [singularity_img]
153152
+ task.definition._command_args(root=self.root),
154-
strip=task.strip,
155153
)
156154
output = dict(zip(keys, values))
157155
if output["return_code"]:

0 commit comments

Comments
 (0)