Skip to content

Commit 52c47e5

Browse files
committed
checkpoint exceptions
1 parent 82f88a7 commit 52c47e5

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

parsl/dataflow/memoization.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import typeguard
1212

1313
from parsl.dataflow.errors import BadCheckpoint
14+
from parsl.dataflow.futures import AppFuture
1415
from parsl.dataflow.taskrecord import TaskRecord
1516

1617
if TYPE_CHECKING:
@@ -336,8 +337,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A
336337
data = pickle.load(f)
337338
# Copy and hash only the input attributes
338339
memo_fu: Future = Future()
339-
assert data['exception'] is None
340-
memo_fu.set_result(data['result'])
340+
341+
if data['exception'] is None:
342+
memo_fu.set_result(data['result'])
343+
else:
344+
assert data['result'] is None
345+
memo_fu.set_exception(data['exception'])
341346
memo_lookup_table[data['hash']] = memo_fu
342347

343348
except EOFError:
@@ -411,17 +416,22 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
411416

412417
app_fu = task_record['app_fu']
413418

414-
if app_fu.done() and app_fu.exception() is None:
419+
if app_fu.done() and self.filter_for_checkpoint(app_fu):
420+
415421
hashsum = task_record['hashsum']
416422
if not hashsum:
417423
continue
418-
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
424+
425+
if app_fu.exception() is None:
426+
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
427+
else:
428+
t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None}
419429

420430
# We are using pickle here since pickle dumps to a file in 'ab'
421431
# mode behave like a incremental log.
422432
pickle.dump(t, f)
423433
count += 1
424-
logger.debug("Task {} checkpointed".format(task_id))
434+
logger.debug("Task {} checkpointed as result".format(task_id))
425435

426436
self.checkpointed_tasks += count
427437

@@ -434,3 +444,7 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
434444
logger.info("Done checkpointing {} tasks".format(count))
435445

436446
return checkpoint_dir
447+
448+
def filter_for_checkpoint(self, app_fu: AppFuture) -> bool:
449+
"""Overridable method to decide if an entry should be checkpointed"""
450+
return app_fu.exception() is None
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import contextlib
2+
import os
3+
4+
import pytest
5+
6+
import parsl
7+
from parsl import python_app
8+
from parsl.config import Config
9+
from parsl.dataflow.memoization import BasicMemoizer
10+
from parsl.executors.threads import ThreadPoolExecutor
11+
12+
13+
class CheckpointExceptionsMemoizer(BasicMemoizer):
14+
def filter_for_checkpoint(self, app_fu):
15+
# checkpoint everything, rather than selecting only futures with
16+
# results, not exceptions.
17+
18+
# task record is available from app_fu.task_record
19+
assert app_fu.task_record is not None
20+
21+
return True
22+
23+
24+
def fresh_config():
25+
return Config(
26+
memoizer=CheckpointExceptionsMemoizer(),
27+
executors=[
28+
ThreadPoolExecutor(
29+
label='local_threads_checkpoint',
30+
)
31+
]
32+
)
33+
34+
35+
@contextlib.contextmanager
36+
def parsl_configured(run_dir, **kw):
37+
c = fresh_config()
38+
c.run_dir = run_dir
39+
for config_attr, config_val in kw.items():
40+
setattr(c, config_attr, config_val)
41+
dfk = parsl.load(c)
42+
for ex in dfk.executors.values():
43+
ex.working_dir = run_dir
44+
yield dfk
45+
46+
parsl.dfk().cleanup()
47+
48+
49+
@python_app(cache=True)
50+
def uuid_app():
51+
import uuid
52+
raise RuntimeError(str(uuid.uuid4()))
53+
54+
55+
@pytest.mark.local
56+
def test_loading_checkpoint(tmpd_cwd):
57+
"""Load memoization table from previous checkpoint
58+
"""
59+
with parsl_configured(tmpd_cwd, checkpoint_mode="task_exit"):
60+
checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")]
61+
result = uuid_app().exception()
62+
63+
with parsl_configured(tmpd_cwd, checkpoint_files=checkpoint_files):
64+
relaunched = uuid_app().exception()
65+
66+
assert result.args == relaunched.args, "Expected following call to uuid_app to return cached uuid in exception"

0 commit comments

Comments
 (0)