Skip to content

Commit c86dec4

Browse files
committed
make checkpoint move away from always using futures for result value
and happen in update memo
1 parent 895a6ec commit c86dec4

File tree

2 files changed

+68
-31
lines changed

2 files changed

+68
-31
lines changed

parsl/dataflow/dflow.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,6 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
547547
if not task_record['app_fu'] == future:
548548
logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id))
549549

550-
self.memoizer.update_checkpoint(task_record)
551-
552550
self.wipe_task(task_id)
553551
return
554552

parsl/dataflow/memoization.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from parsl.dataflow.errors import BadCheckpoint
1616
from parsl.dataflow.taskrecord import TaskRecord
17-
from parsl.errors import ConfigurationError, InternalConsistencyError
17+
from parsl.errors import ConfigurationError
1818
from parsl.utils import Timer, get_all_checkpoints
1919

2020
logger = logging.getLogger(__name__)
@@ -289,9 +289,29 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
289289
def update_memo_result(self, task: TaskRecord, r: Any) -> None:
290290
self._update_memo(task)
291291

292+
if self.checkpoint_mode == 'task_exit':
293+
self.checkpoint(task=task, result=r)
294+
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
295+
# with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock
296+
self.checkpointable_tasks.append(task)
297+
elif self.checkpoint_mode is None:
298+
pass
299+
else:
300+
assert False, f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
301+
292302
def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None:
293303
self._update_memo(task)
294304

305+
if self.checkpoint_mode == 'task_exit':
306+
self.checkpoint(task=task, exception=e)
307+
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
308+
# with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock
309+
self.checkpointable_tasks.append(task)
310+
elif self.checkpoint_mode is None:
311+
pass
312+
else:
313+
assert False, f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
314+
295315
def _update_memo(self, task: TaskRecord) -> None:
296316
"""Updates the memoization lookup table with the result from a task.
297317
This doesn't move any values around but associates the memoization
@@ -383,18 +403,17 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
383403
else:
384404
return {}
385405

386-
def update_checkpoint(self, task_record: TaskRecord) -> None:
387-
if self.checkpoint_mode == 'task_exit':
388-
self.checkpoint(task=task_record)
389-
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
390-
with self.checkpoint_lock:
391-
self.checkpointable_tasks.append(task_record)
392-
elif self.checkpoint_mode is None:
393-
pass
394-
else:
395-
raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
396-
397-
def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
406+
# TODO: this call becomes even more multiplexed...
407+
# called with no parameters, we write out the task
408+
# called with a task record, we can now no longer expect to get the
409+
# result from the task record future, because it will not be
410+
# populated yet.
411+
# so then either we have an exception, or if exception is None, then
412+
# checkpoint result. it's possible that result can be None as a
413+
# real result: in the case that exception is None.
414+
# what a horrible API that needs refactoring...
415+
416+
def checkpoint(self, *, task: Optional[TaskRecord] = None, exception: Optional[BaseException] = None, result: Any = None) -> None:
398417
"""Checkpoint the dfk incrementally to a checkpoint file.
399418
400419
When called with no argument, all tasks registered in self.checkpointable_tasks
@@ -414,11 +433,6 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
414433
"""
415434
with self.checkpoint_lock:
416435

417-
if task:
418-
checkpoint_queue = [task]
419-
else:
420-
checkpoint_queue = self.checkpointable_tasks
421-
422436
checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
423437
checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
424438

@@ -428,22 +442,47 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
428442
count = 0
429443

430444
with open(checkpoint_tasks, 'ab') as f:
431-
for task_record in checkpoint_queue:
432-
task_id = task_record['id']
433445

434-
app_fu = task_record['app_fu']
446+
if task:
447+
# TODO: refactor with below
448+
449+
task_id = task['id']
435450

436-
if app_fu.done() and app_fu.exception() is None:
437-
hashsum = task_record['hashsum']
451+
if exception is None:
452+
hashsum = task['hashsum']
438453
if not hashsum:
439-
continue
440-
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
454+
pass # TODO: log an error? see below discussion
455+
else:
456+
t = {'hash': hashsum, 'exception': None, 'result': result}
457+
458+
# We are using pickle here since pickle dumps to a file in 'ab'
459+
# mode behaves like an incremental log.
460+
pickle.dump(t, f)
461+
count += 1
441462

442-
# We are using pickle here since pickle dumps to a file in 'ab'
443-
# mode behaves like an incremental log.
444-
pickle.dump(t, f)
445-
count += 1
446463
logger.debug("Task {} checkpointed".format(task_id))
464+
else:
465+
checkpoint_queue = self.checkpointable_tasks
466+
467+
for task_record in checkpoint_queue:
468+
task_id = task_record['id']
469+
470+
app_fu = task_record['app_fu']
471+
472+
assert app_fu.done(), "trying to checkpoint a task that is not done"
473+
474+
if app_fu.done() and app_fu.exception() is None:
475+
hashsum = task_record['hashsum']
476+
if not hashsum:
477+
continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
478+
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
479+
480+
# We are using pickle here since pickle dumps to a file in 'ab'
481+
# mode behaves like an incremental log.
482+
pickle.dump(t, f)
483+
count += 1
484+
485+
logger.debug("Task {} checkpointed".format(task_id))
447486

448487
self.checkpointed_tasks += count
449488

0 commit comments

Comments
 (0)