Skip to content

Commit 713939b

Browse files
committed
debugged test_tasks
1 parent a368c40 commit 713939b

File tree

4 files changed

+88
-63
lines changed

4 files changed

+88
-63
lines changed

pydra/engine/core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ def __init__(
100100
name: str,
101101
environment: "Environment | None" = None,
102102
state_index: "state.StateIndex | None" = None,
103+
pre_run: ty.Callable["Task", None] | None = None,
104+
pre_run_task: ty.Callable["Task", None] | None = None,
105+
post_run_task: ty.Callable["Task", None] | None = None,
106+
post_run: ty.Callable["Task", None] | None = None,
103107
):
104108
"""
105109
Initialize a task.
@@ -142,7 +146,12 @@ def __init__(
142146
self.allow_cache_override = True
143147
self._checksum = None
144148
self._uid = uuid4().hex
145-
self.hooks = TaskHook()
149+
self.hooks = TaskHook(
150+
pre_run=pre_run,
151+
post_run=post_run,
152+
pre_run_task=pre_run_task,
153+
post_run_task=post_run_task,
154+
)
146155
self._errored = False
147156
self._lzout = None
148157

pydra/engine/specs.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from copy import deepcopy
1515
from typing import Self
1616
import attrs
17+
from attrs.converters import default_if_none
1718
import cloudpickle as cp
1819
from fileformats.core import FileSet
1920
from pydra.utils.messenger import AuditFlag, Messenger
@@ -160,6 +161,10 @@ def __call__(
160161
messengers: ty.Iterable[Messenger] | None = None,
161162
messenger_args: dict[str, ty.Any] | None = None,
162163
name: str | None = None,
164+
pre_run: ty.Callable["Task", None] | None = None,
165+
post_run: ty.Callable["Task", None] | None = None,
166+
pre_run_task: ty.Callable["Task", None] | None = None,
167+
post_run_task: ty.Callable["Task", None] | None = None,
163168
**kwargs: ty.Any,
164169
) -> OutputsType:
165170
"""Create a task from this definition and execute it to produce a result.
@@ -212,8 +217,17 @@ def __call__(
212217
worker=worker,
213218
**kwargs,
214219
) as sub:
215-
result = sub(self, name=name)
220+
result = sub(
221+
self,
222+
name=name,
223+
pre_run=pre_run,
224+
post_run=post_run,
225+
pre_run_task=pre_run_task,
226+
post_run_task=post_run_task,
227+
)
216228
except TypeError as e:
229+
# Catch any inadvertent passing of task definition parameters to the
230+
# execution call
217231
if hasattr(e, "__notes__") and WORKER_KWARG_FAIL_NOTE in e.__notes__:
218232
if match := re.match(
219233
r".*got an unexpected keyword argument '(\w+)'", str(e)
@@ -1248,15 +1262,18 @@ def donothing(*args: ty.Any, **kwargs: ty.Any) -> None:
12481262
class TaskHook:
12491263
"""Callable task hooks."""
12501264

1251-
pre_run_task: ty.Callable = donothing
1252-
post_run_task: ty.Callable = donothing
1253-
pre_run: ty.Callable = donothing
1254-
post_run: ty.Callable = donothing
1255-
1256-
def __setattr__(self, attr, val):
1257-
if attr not in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:
1258-
raise AttributeError("Cannot set unknown hook")
1259-
super().__setattr__(attr, val)
1265+
pre_run_task: ty.Callable = attrs.field(
1266+
default=donothing, converter=default_if_none(donothing)
1267+
)
1268+
post_run_task: ty.Callable = attrs.field(
1269+
default=donothing, converter=default_if_none(donothing)
1270+
)
1271+
pre_run: ty.Callable = attrs.field(
1272+
default=donothing, converter=default_if_none(donothing)
1273+
)
1274+
post_run: ty.Callable = attrs.field(
1275+
default=donothing, converter=default_if_none(donothing)
1276+
)
12601277

12611278
def reset(self):
12621279
for val in ["pre_run_task", "post_run_task", "pre_run", "post_run"]:

pydra/engine/submitter.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def __call__(
169169
self,
170170
task_def: "TaskDef",
171171
name: str | None = "task",
172+
pre_run: ty.Callable["Task", None] | None = None,
173+
post_run: ty.Callable["Task", None] | None = None,
174+
pre_run_task: ty.Callable["Task", None] | None = None,
175+
post_run_task: ty.Callable["Task", None] | None = None,
172176
):
173177
"""Submitter run function."""
174178

@@ -194,7 +198,16 @@ def Split(defn: TaskDef, output_types: dict):
194198
f"Task {self} is marked for combining, but not splitting. "
195199
"Use the `split` method to split the task before combining."
196200
)
197-
task = Task(task_def, submitter=self, name=name, environment=self.environment)
201+
task = Task(
202+
task_def,
203+
submitter=self,
204+
name=name,
205+
environment=self.environment,
206+
pre_run=pre_run,
207+
post_run=post_run,
208+
pre_run_task=pre_run_task,
209+
post_run_task=post_run_task,
210+
)
198211
try:
199212
self.run_start_time = datetime.now()
200213
if self.worker.is_async: # Only workflow tasks can be async
@@ -203,6 +216,12 @@ def Split(defn: TaskDef, output_types: dict):
203216
)
204217
else:
205218
self.worker.run(task, rerun=self.rerun)
219+
except Exception as e:
220+
e.add_note(
221+
f"Full crash report for {type(task_def).__name__!r} task is here: "
222+
+ str(task.output_dir / "_error.pklz")
223+
)
224+
raise e
206225
finally:
207226
self.run_start_time = None
208227
PersistentCache().clean_up()

pydra/engine/tests/test_task.py

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ def test_taskhooks_1(tmpdir: Path, capsys):
11981198
cache_dir.mkdir()
11991199

12001200
foo = Task(
1201-
definition=FunAddTwo(a=1), submitter=Submitter(cache_dir=cache_dir), name="foo"
1201+
definition=FunAddTwo(a=1), submitter=Submitter(cache_dir=tmpdir), name="foo"
12021202
)
12031203
assert foo.hooks
12041204
# ensure all hooks are defined
@@ -1209,8 +1209,7 @@ def test_taskhooks_1(tmpdir: Path, capsys):
12091209
def myhook(task, *args):
12101210
print("I was called")
12111211

1212-
foo.hooks.pre_run = myhook
1213-
foo.run()
1212+
FunAddTwo(a=1)(cache_dir=cache_dir, pre_run=myhook)
12141213
captured = capsys.readouterr()
12151214
assert "I was called\n" in captured.out
12161215
del captured
@@ -1219,52 +1218,31 @@ def myhook(task, *args):
12191218
with pytest.raises(AttributeError):
12201219
foo.hooks.mid_run = myhook
12211220

1221+
# reset all hooks
1222+
foo.hooks.reset()
1223+
for attr in ("pre_run", "post_run", "pre_run_task", "post_run_task"):
1224+
hook = getattr(foo.hooks, attr)
1225+
assert hook() is None
1226+
12221227
# clear cache
12231228
shutil.rmtree(cache_dir)
12241229
cache_dir.mkdir()
12251230

12261231
# set all hooks
1227-
foo.hooks.post_run = myhook
1228-
foo.hooks.pre_run_task = myhook
1229-
foo.hooks.post_run_task = myhook
1230-
foo.run()
1231-
captured = capsys.readouterr()
1232-
assert captured.out.count("I was called\n") == 4
1233-
del captured
1234-
1235-
# hooks are independent across tasks by default
1236-
bar = Task(
1237-
definition=FunAddTwo(a=3), name="bar", submitter=Submitter(cache_dir=tmpdir)
1232+
FunAddTwo(a=1)(
1233+
cache_dir=cache_dir,
1234+
pre_run=myhook,
1235+
post_run=myhook,
1236+
pre_run_task=myhook,
1237+
post_run_task=myhook,
12381238
)
1239-
assert bar.hooks is not foo.hooks
1240-
# but can be shared across tasks
1241-
bar.hooks = foo.hooks
1242-
# and workflows
1243-
wf_task = Task(
1244-
definition=BasicWorkflow(x=1),
1245-
submitter=Submitter(cache_dir=tmpdir, worker="cf"),
1246-
name="wf",
1247-
)
1248-
wf_task.hooks = bar.hooks
1249-
assert foo.hooks == bar.hooks == wf_task.hooks
1250-
1251-
wf_task.run()
12521239
captured = capsys.readouterr()
12531240
assert captured.out.count("I was called\n") == 4
12541241
del captured
12551242

1256-
# reset all hooks
1257-
foo.hooks.reset()
1258-
for attr in ("pre_run", "post_run", "pre_run_task", "post_run_task"):
1259-
hook = getattr(foo.hooks, attr)
1260-
assert hook() is None
1261-
12621243

12631244
def test_taskhooks_2(tmpdir, capsys):
12641245
"""checking order of the hooks; using task's attributes"""
1265-
foo = Task(
1266-
definition=FunAddTwo(a=1), name="foo", submitter=Submitter(cache_dir=tmpdir)
1267-
)
12681246

12691247
def myhook_prerun(task, *args):
12701248
print(f"i. prerun hook was called from {task.name}")
@@ -1278,11 +1256,13 @@ def myhook_postrun_task(task, *args):
12781256
def myhook_postrun(task, *args):
12791257
print(f"iv. postrun hook was called {task.name}")
12801258

1281-
foo.hooks.pre_run = myhook_prerun
1282-
foo.hooks.post_run = myhook_postrun
1283-
foo.hooks.pre_run_task = myhook_prerun_task
1284-
foo.hooks.post_run_task = myhook_postrun_task
1285-
foo.run()
1259+
FunAddTwo(a=1)(
1260+
cache_dir=tmpdir,
1261+
pre_run=myhook_prerun,
1262+
post_run=myhook_postrun,
1263+
pre_run_task=myhook_prerun_task,
1264+
post_run_task=myhook_postrun_task,
1265+
)
12861266

12871267
captured = capsys.readouterr()
12881268
hook_messages = captured.out.strip().split("\n")
@@ -1318,21 +1298,17 @@ def myhook_postrun(task, result, *args):
13181298

13191299
def test_taskhooks_4(tmpdir, capsys):
13201300
"""task raises an error: postrun task should be called, postrun shouldn't be called"""
1321-
foo = Task(
1322-
definition=FunAddTwo(a="one"), name="foo", submitter=Submitter(cache_dir=tmpdir)
1323-
)
13241301

13251302
def myhook_postrun_task(task, result, *args):
13261303
print(f"postrun task hook was called, result object is {result}")
13271304

13281305
def myhook_postrun(task, result, *args):
13291306
print("postrun hook should not be called")
13301307

1331-
foo.hooks.post_run = myhook_postrun
1332-
foo.hooks.post_run_task = myhook_postrun_task
1333-
13341308
with pytest.raises(Exception):
1335-
foo()
1309+
FunAddTwo(a="one")(
1310+
cache_dir=tmpdir, post_run=myhook_postrun, post_run_task=myhook_postrun_task
1311+
)
13361312

13371313
captured = capsys.readouterr()
13381314
hook_messages = captured.out.strip().split("\n")
@@ -1351,11 +1327,13 @@ def test_traceback(tmpdir):
13511327
def FunError(x):
13521328
raise Exception("Error from the function")
13531329

1354-
with pytest.raises(Exception, match="Task 'FunError' raised an error") as exinfo:
1330+
with pytest.raises(Exception, match="Error from the function") as exinfo:
13551331
FunError(x=3)(worker="cf", cache_dir=tmpdir)
13561332

13571333
# getting error file from the error message
1358-
error_file_match = str(exinfo.value).split("here: ")[-1].split("_error.pklz")[0]
1334+
error_file_match = (
1335+
str(exinfo.value.__notes__[0]).split("here: ")[-1].split("_error.pklz")[0]
1336+
)
13591337
error_file = Path(error_file_match) / "_error.pklz"
13601338
# checking if the file exists
13611339
assert error_file.exists()
@@ -1386,7 +1364,9 @@ def Workflow(x_list):
13861364
wf(worker="cf")
13871365

13881366
# getting error file from the error message
1389-
error_file_match = str(exinfo.value).split("here: ")[-1].split("_error.pklz")[0]
1367+
error_file_match = (
1368+
str(exinfo.value).split("here: ")[-1].split("_error.pklz")[0].strip()
1369+
)
13901370
error_file = Path(error_file_match) / "_error.pklz"
13911371
# checking if the file exists
13921372
assert error_file.exists()

0 commit comments

Comments
 (0)