Skip to content

Commit a2e3804

Browse files
committed
checkpoint exceptions
1 parent 87d6413 commit a2e3804

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

parsl/dataflow/memoization.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import typeguard
1212

1313
from parsl.dataflow.errors import BadCheckpoint
14+
from parsl.dataflow.futures import AppFuture
1415
from parsl.dataflow.taskrecord import TaskRecord
1516

1617
if TYPE_CHECKING:
@@ -339,8 +340,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A
339340
data = pickle.load(f)
340341
# Copy and hash only the input attributes
341342
memo_fu: Future = Future()
342-
assert data['exception'] is None
343-
memo_fu.set_result(data['result'])
343+
344+
if data['exception'] is None:
345+
memo_fu.set_result(data['result'])
346+
else:
347+
assert data['result'] is None
348+
memo_fu.set_exception(data['exception'])
344349
memo_lookup_table[data['hash']] = memo_fu
345350

346351
except EOFError:
@@ -414,17 +419,22 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
414419

415420
app_fu = task_record['app_fu']
416421

417-
if app_fu.done() and app_fu.exception() is None:
422+
if app_fu.done() and self.filter_for_checkpoint(app_fu):
423+
418424
hashsum = task_record['hashsum']
419425
if not hashsum:
420426
continue
421-
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
427+
428+
if app_fu.exception() is None:
429+
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
430+
else:
431+
t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None}
422432

423433
# We are using pickle here since pickle dumps to a file in 'ab'
424434
# mode behave like an incremental log.
425435
pickle.dump(t, f)
426436
count += 1
427-
logger.debug("Task {} checkpointed".format(task_id))
437+
logger.debug("Task {} checkpointed as result".format(task_id))
428438

429439
self.checkpointed_tasks += count
430440

@@ -435,3 +445,7 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
435445
logger.debug("No tasks checkpointed in this pass.")
436446
else:
437447
logger.info("Done checkpointing {} tasks".format(count))
448+
449+
def filter_for_checkpoint(self, app_fu: AppFuture) -> bool:
    """Decide whether the outcome of ``app_fu`` should be checkpointed.

    This is a hook that subclasses may override to change which entries
    are selected. The default selects only futures that hold no exception.
    """
    outcome_exception = app_fu.exception()
    return outcome_exception is None
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import contextlib
2+
import os
3+
4+
import pytest
5+
6+
import parsl
7+
from parsl import python_app
8+
from parsl.config import Config
9+
from parsl.dataflow.memoization import BasicMemoizer
10+
from parsl.executors.threads import ThreadPoolExecutor
11+
12+
13+
class CheckpointExceptionsMemoizer(BasicMemoizer):
    """Memoizer whose checkpoint filter accepts every future it is given.

    The filter returns True unconditionally, so futures holding exceptions
    are selected for checkpointing as well as futures holding results.
    """

    def filter_for_checkpoint(self, app_fu):
        """Accept ``app_fu`` for checkpointing regardless of its outcome."""
        # The task record is reachable through app_fu.task_record; verify
        # that it has been populated by the time this hook is invoked.
        task_record = app_fu.task_record
        assert task_record is not None

        return True
22+
23+
24+
def fresh_config():
    """Build a new Config with the exception-checkpointing memoizer.

    The config uses a CheckpointExceptionsMemoizer and a single
    ThreadPoolExecutor labelled 'local_threads_checkpoint'.
    """
    exception_memoizer = CheckpointExceptionsMemoizer()
    thread_executor = ThreadPoolExecutor(label='local_threads_checkpoint')
    return Config(memoizer=exception_memoizer, executors=[thread_executor])
33+
34+
35+
@contextlib.contextmanager
def parsl_configured(run_dir, **kw):
    """Load parsl with a fresh config rooted at ``run_dir`` and yield the DFK.

    Args:
        run_dir: directory assigned to the config's ``run_dir`` and to the
            ``working_dir`` of every executor of the loaded DFK.
        **kw: extra attributes set on the config before it is loaded, for
            example ``checkpoint_mode`` or ``checkpoint_files``.

    Yields:
        The object returned by ``parsl.load``.

    The current DFK is cleaned up when the ``with`` block exits, including
    when the block raises.
    """
    c = fresh_config()
    c.run_dir = run_dir
    for config_attr, config_val in kw.items():
        setattr(c, config_attr, config_val)
    dfk = parsl.load(c)
    # Without try/finally, an exception raised inside the ``with`` body would
    # skip cleanup and leave the DFK loaded for whatever runs next.
    try:
        for ex in dfk.executors.values():
            ex.working_dir = run_dir
        yield dfk
    finally:
        parsl.dfk().cleanup()
47+
48+
49+
@python_app(cache=True)
def uuid_app():
    """Always fail with a RuntimeError whose message is a fresh random UUID."""
    from uuid import uuid4
    unique_message = str(uuid4())
    raise RuntimeError(unique_message)
53+
54+
55+
@pytest.mark.local
def test_loading_checkpoint(tmpd_cwd):
    """Check that an exception is replayed from a previous checkpoint.

    The first run invokes uuid_app with ``checkpoint_mode="task_exit"``.
    The second run is given the first run's checkpoint directory, and its
    exception must carry the same args as the first one.
    """
    with parsl_configured(tmpd_cwd, checkpoint_mode="task_exit"):
        checkpoint_dir = os.path.join(parsl.dfk().run_dir, "checkpoint")
        first_exc = uuid_app().exception()

    with parsl_configured(tmpd_cwd, checkpoint_files=[checkpoint_dir]):
        second_exc = uuid_app().exception()

    assert first_exc.args == second_exc.args, "Expected following call to uuid_app to return cached uuid in exception"

0 commit comments

Comments
 (0)