Skip to content

Commit d6ae5ea

Browse files
committed
moved hooks arguments into single TaskHooks object in run methods
1 parent 7c97e55 commit d6ae5ea

File tree

4 files changed

+50
-66
lines changed

4 files changed

+50
-66
lines changed

pydra/engine/core.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .specs import (
3030
RuntimeSpec,
3131
Result,
32-
TaskHook,
32+
TaskHooks,
3333
)
3434
from .helpers import (
3535
attrs_fields,
@@ -100,10 +100,7 @@ def __init__(
100100
name: str,
101101
environment: "Environment | None" = None,
102102
state_index: "state.StateIndex | None" = None,
103-
pre_run: ty.Callable["Task", None] | None = None,
104-
pre_run_task: ty.Callable["Task", None] | None = None,
105-
post_run_task: ty.Callable["Task", None] | None = None,
106-
post_run: ty.Callable["Task", None] | None = None,
103+
hooks: TaskHooks | None = None,
107104
):
108105
"""
109106
Initialize a task.
@@ -146,12 +143,7 @@ def __init__(
146143
self.allow_cache_override = True
147144
self._checksum = None
148145
self._uid = uuid4().hex
149-
self.hooks = TaskHook(
150-
pre_run=pre_run,
151-
post_run=post_run,
152-
pre_run_task=pre_run_task,
153-
post_run_task=post_run_task,
154-
)
146+
self.hooks = hooks if hooks is not None else TaskHooks()
155147
self._errored = False
156148
self._lzout = None
157149

pydra/engine/specs.py

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,32 @@ def __eq__(self, other: ty.Any) -> bool:
138138
OutputsType = ty.TypeVar("OutputType", bound=TaskOutputs)
139139

140140

141+
def donothing(*args: ty.Any, **kwargs: ty.Any) -> None:
142+
return None
143+
144+
145+
@attrs.define(kw_only=True)
146+
class TaskHooks:
147+
"""Callable task hooks."""
148+
149+
pre_run_task: ty.Callable = attrs.field(
150+
default=donothing, converter=default_if_none(donothing)
151+
)
152+
post_run_task: ty.Callable = attrs.field(
153+
default=donothing, converter=default_if_none(donothing)
154+
)
155+
pre_run: ty.Callable = attrs.field(
156+
default=donothing, converter=default_if_none(donothing)
157+
)
158+
post_run: ty.Callable = attrs.field(
159+
default=donothing, converter=default_if_none(donothing)
160+
)
161+
162+
def reset(self):
163+
for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
164+
setattr(self, val, donothing)
165+
166+
141167
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
142168
class TaskDef(ty.Generic[OutputsType]):
143169
"""Base class for all task definitions"""
@@ -161,10 +187,7 @@ def __call__(
161187
messengers: ty.Iterable[Messenger] | None = None,
162188
messenger_args: dict[str, ty.Any] | None = None,
163189
name: str | None = None,
164-
pre_run: ty.Callable["Task", None] | None = None,
165-
post_run: ty.Callable["Task", None] | None = None,
166-
pre_run_task: ty.Callable["Task", None] | None = None,
167-
post_run_task: ty.Callable["Task", None] | None = None,
190+
hooks: TaskHooks | None = None,
168191
**kwargs: ty.Any,
169192
) -> OutputsType:
170193
"""Create a task from this definition and execute it to produce a result.
@@ -220,10 +243,7 @@ def __call__(
220243
result = sub(
221244
self,
222245
name=name,
223-
pre_run=pre_run,
224-
post_run=post_run,
225-
pre_run_task=pre_run_task,
226-
post_run_task=post_run_task,
246+
hooks=hooks,
227247
)
228248
except TypeError as e:
229249
# Catch any inadvertent passing of task definition parameters to the
@@ -1254,32 +1274,6 @@ def _generated_output_names(self, stdout: str, stderr: str):
12541274
DEFAULT_COPY_COLLATION = FileSet.CopyCollation.adjacent
12551275

12561276

1257-
def donothing(*args: ty.Any, **kwargs: ty.Any) -> None:
1258-
return None
1259-
1260-
1261-
@attrs.define(kw_only=True)
1262-
class TaskHook:
1263-
"""Callable task hooks."""
1264-
1265-
pre_run_task: ty.Callable = attrs.field(
1266-
default=donothing, converter=default_if_none(donothing)
1267-
)
1268-
post_run_task: ty.Callable = attrs.field(
1269-
default=donothing, converter=default_if_none(donothing)
1270-
)
1271-
pre_run: ty.Callable = attrs.field(
1272-
default=donothing, converter=default_if_none(donothing)
1273-
)
1274-
post_run: ty.Callable = attrs.field(
1275-
default=donothing, converter=default_if_none(donothing)
1276-
)
1277-
1278-
def reset(self):
1279-
for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
1280-
setattr(self, val, donothing)
1281-
1282-
12831277
def split_cmd(cmd: str | None):
12841278
"""Splits a shell command line into separate arguments respecting quotes
12851279

pydra/engine/submitter.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
if ty.TYPE_CHECKING:
3131
from .node import Node
32-
from .specs import TaskDef, WorkflowDef
32+
from .specs import TaskDef, WorkflowDef, TaskHooks
3333
from .environments import Environment
3434
from .state import State
3535

@@ -169,10 +169,7 @@ def __call__(
169169
self,
170170
task_def: "TaskDef",
171171
name: str | None = "task",
172-
pre_run: ty.Callable["Task", None] | None = None,
173-
post_run: ty.Callable["Task", None] | None = None,
174-
pre_run_task: ty.Callable["Task", None] | None = None,
175-
post_run_task: ty.Callable["Task", None] | None = None,
172+
hooks: "TaskHooks | None" = None,
176173
):
177174
"""Submitter run function."""
178175

@@ -203,10 +200,7 @@ def Split(defn: TaskDef, output_types: dict):
203200
submitter=self,
204201
name=name,
205202
environment=self.environment,
206-
pre_run=pre_run,
207-
post_run=post_run,
208-
pre_run_task=pre_run_task,
209-
post_run_task=post_run_task,
203+
hooks=hooks,
210204
)
211205
try:
212206
self.run_start_time = datetime.now()

pydra/engine/tests/test_task.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
from pydra.design import python, shell, workflow
1212
from pydra.utils.messenger import FileMessenger, PrintMessenger, collect_messages
1313
from ..task import AuditFlag
14-
from pydra.engine.specs import argstr_formatting, ShellDef, ShellOutputs
14+
from pydra.engine.specs import argstr_formatting, ShellDef, ShellOutputs, TaskHooks
1515
from pydra.engine.helpers import list_fields, print_help
1616
from pydra.engine.submitter import Submitter
1717
from pydra.engine.core import Task
18-
from .utils import BasicWorkflow
1918
from pydra.utils import default_run_cache_dir
2019
from pydra.utils.typing import (
2120
MultiInputObj,
@@ -1209,7 +1208,7 @@ def test_taskhooks_1(tmpdir: Path, capsys):
12091208
def myhook(task, *args):
12101209
print("I was called")
12111210

1212-
FunAddTwo(a=1)(cache_dir=cache_dir, pre_run=myhook)
1211+
FunAddTwo(a=1)(cache_dir=cache_dir, hooks=TaskHooks(pre_run=myhook))
12131212
captured = capsys.readouterr()
12141213
assert "I was called\n" in captured.out
12151214
del captured
@@ -1231,10 +1230,12 @@ def myhook(task, *args):
12311230
# set all hooks
12321231
FunAddTwo(a=1)(
12331232
cache_dir=cache_dir,
1234-
pre_run=myhook,
1235-
post_run=myhook,
1236-
pre_run_task=myhook,
1237-
post_run_task=myhook,
1233+
hooks=TaskHooks(
1234+
pre_run=myhook,
1235+
post_run=myhook,
1236+
pre_run_task=myhook,
1237+
post_run_task=myhook,
1238+
),
12381239
)
12391240
captured = capsys.readouterr()
12401241
assert captured.out.count("I was called\n") == 4
@@ -1258,10 +1259,12 @@ def myhook_postrun(task, *args):
12581259

12591260
FunAddTwo(a=1)(
12601261
cache_dir=tmpdir,
1261-
pre_run=myhook_prerun,
1262-
post_run=myhook_postrun,
1263-
pre_run_task=myhook_prerun_task,
1264-
post_run_task=myhook_postrun_task,
1262+
hooks=TaskHooks(
1263+
pre_run=myhook_prerun,
1264+
post_run=myhook_postrun,
1265+
pre_run_task=myhook_prerun_task,
1266+
post_run_task=myhook_postrun_task,
1267+
),
12651268
)
12661269

12671270
captured = capsys.readouterr()
@@ -1307,7 +1310,8 @@ def myhook_postrun(task, result, *args):
13071310

13081311
with pytest.raises(Exception):
13091312
FunAddTwo(a="one")(
1310-
cache_dir=tmpdir, post_run=myhook_postrun, post_run_task=myhook_postrun_task
1313+
cache_dir=tmpdir,
1314+
hooks=TaskHooks(post_run=myhook_postrun, post_run_task=myhook_postrun_task),
13111315
)
13121316

13131317
captured = capsys.readouterr()

0 commit comments

Comments
 (0)