Skip to content

Commit 1846145

Browse files
committed
checkpoint exceptions
1 parent fd311d0 commit 1846145

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

parsl/dataflow/memoization.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import typeguard
1212

1313
from parsl.dataflow.errors import BadCheckpoint
14+
from parsl.dataflow.futures import AppFuture
1415
from parsl.dataflow.taskrecord import TaskRecord
1516

1617
if TYPE_CHECKING:
@@ -336,8 +337,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A
336337
data = pickle.load(f)
337338
# Copy and hash only the input attributes
338339
memo_fu: Future = Future()
339-
assert data['exception'] is None
340-
memo_fu.set_result(data['result'])
340+
341+
if data['exception'] is None:
342+
memo_fu.set_result(data['result'])
343+
else:
344+
assert data['result'] is None
345+
memo_fu.set_exception(data['exception'])
341346
memo_lookup_table[data['hash']] = memo_fu
342347

343348
except EOFError:
@@ -418,17 +423,22 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
418423

419424
app_fu = task_record['app_fu']
420425

421-
if app_fu.done() and app_fu.exception() is None:
426+
if app_fu.done() and self.filter_for_checkpoint(app_fu):
427+
422428
hashsum = task_record['hashsum']
423429
if not hashsum:
424430
continue
425-
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
431+
432+
if app_fu.exception() is None:
433+
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
434+
else:
435+
t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None}
426436

427437
# We are using pickle here since pickle dumps to a file in 'ab'
428438
# mode behave like a incremental log.
429439
pickle.dump(t, f)
430440
count += 1
431-
logger.debug("Task {} checkpointed".format(task_id))
441+
logger.debug("Task {} checkpointed as result".format(task_id))
432442

433443
self.checkpointed_tasks += count
434444

@@ -441,3 +451,7 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
441451
logger.info("Done checkpointing {} tasks".format(count))
442452

443453
return checkpoint_dir
454+
455+
def filter_for_checkpoint(self, app_fu: AppFuture) -> bool:
456+
"""Overridable method to decide if an entry should be checkpointed"""
457+
return app_fu.exception() is None
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import contextlib
2+
import os
3+
4+
import pytest
5+
6+
import parsl
7+
from parsl import python_app
8+
from parsl.config import Config
9+
from parsl.dataflow.memoization import BasicMemoizer
10+
from parsl.executors.threads import ThreadPoolExecutor
11+
12+
13+
class CheckpointExceptionsMemoizer(BasicMemoizer):
14+
def filter_for_checkpoint(self, app_fu):
15+
# checkpoint everything, rather than selecting only futures with
16+
# results, not exceptions.
17+
18+
# task record is available from app_fu.task_record
19+
assert app_fu.task_record is not None
20+
21+
return True
22+
23+
24+
def fresh_config():
25+
return Config(
26+
memoizer=CheckpointExceptionsMemoizer(),
27+
executors=[
28+
ThreadPoolExecutor(
29+
label='local_threads_checkpoint',
30+
)
31+
]
32+
)
33+
34+
35+
@contextlib.contextmanager
36+
def parsl_configured(run_dir, **kw):
37+
c = fresh_config()
38+
c.run_dir = run_dir
39+
for config_attr, config_val in kw.items():
40+
setattr(c, config_attr, config_val)
41+
dfk = parsl.load(c)
42+
for ex in dfk.executors.values():
43+
ex.working_dir = run_dir
44+
yield dfk
45+
46+
parsl.dfk().cleanup()
47+
48+
49+
@python_app(cache=True)
50+
def uuid_app():
51+
import uuid
52+
raise RuntimeError(str(uuid.uuid4()))
53+
54+
55+
@pytest.mark.local
56+
def test_loading_checkpoint(tmpd_cwd):
57+
"""Load memoization table from previous checkpoint
58+
"""
59+
with parsl_configured(tmpd_cwd, checkpoint_mode="task_exit"):
60+
checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")]
61+
result = uuid_app().exception()
62+
63+
with parsl_configured(tmpd_cwd, checkpoint_files=checkpoint_files):
64+
relaunched = uuid_app().exception()
65+
66+
assert result.args == relaunched.args, "Expected following call to uuid_app to return cached uuid in exception"

0 commit comments

Comments
 (0)