Skip to content

Commit ae375dd

Browse files
committed
make checkpoint move away from always using futures for result value
and happen in update memo
1 parent 342e2c8 commit ae375dd

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

parsl/dataflow/dflow.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -560,19 +560,6 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
560560
if not task_record['app_fu'] == future:
561561
logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id))
562562

563-
# Cover all checkpointing cases here:
564-
# Do we need to checkpoint now, or queue for later,
565-
# or do nothing?
566-
if self.checkpoint_mode == 'task_exit':
567-
self.memoizer.checkpoint(task=task_record)
568-
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
569-
with self._modify_checkpointable_tasks_lock:
570-
self.memoizer.checkpointable_tasks.append(task_record)
571-
elif self.checkpoint_mode is None:
572-
pass
573-
else:
574-
raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
575-
576563
self.wipe_task(task_id)
577564
return
578565

parsl/dataflow/memoization.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,29 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
263263
def update_memo_result(self, task: TaskRecord, r: Any) -> None:
264264
self._update_memo(task)
265265

266+
if self.checkpoint_mode == 'task_exit':
267+
self.checkpoint(task=task, result=r)
268+
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
269+
# with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock
270+
self.checkpointable_tasks.append(task)
271+
elif self.checkpoint_mode is None:
272+
pass
273+
else:
274+
assert False, "Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
275+
266276
def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None:
267277
self._update_memo(task)
268278

279+
if self.checkpoint_mode == 'task_exit':
280+
self.checkpoint(task=task, exception=e)
281+
elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
282+
# with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock
283+
self.checkpointable_tasks.append(task)
284+
elif self.checkpoint_mode is None:
285+
pass
286+
else:
287+
assert False, "Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
288+
269289
def _update_memo(self, task: TaskRecord) -> None:
270290
"""Updates the memoization lookup table with the result from a task.
271291
This doesn't move any values around but associates the memoization
@@ -357,7 +377,17 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
357377
else:
358378
return {}
359379

360-
def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
380+
# TODO: this call becomes even more multiplexed...
381+
# called with no parameters, we write out all tasks queued in self.checkpointable_tasks
382+
# called with a task record, we can now no longer expect to get the
383+
# result from the task record future, because it will not be
384+
# populated yet.
385+
# so then either we are given an exception, or if exception is None, then
386+
# checkpoint result. it's possible that result can be None as a
387+
# real result: in the case that exception is None.
388+
# what a horrible API that needs refactoring...
389+
390+
def checkpoint(self, *, task: Optional[TaskRecord] = None, exception: Optional[BaseException] = None, result: Any = None) -> None:
361391
"""Checkpoint the dfk incrementally to a checkpoint file.
362392
363393
When called with no argument, all tasks registered in self.checkpointable_tasks
@@ -377,11 +407,6 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
377407
"""
378408
with self.checkpoint_lock:
379409

380-
if task:
381-
checkpoint_queue = [task]
382-
else:
383-
checkpoint_queue = self.checkpointable_tasks
384-
385410
checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
386411
checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
387412

@@ -391,22 +416,47 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
391416
count = 0
392417

393418
with open(checkpoint_tasks, 'ab') as f:
394-
for task_record in checkpoint_queue:
395-
task_id = task_record['id']
396419

397-
app_fu = task_record['app_fu']
420+
if task:
421+
# TODO: refactor with below
398422

399-
if app_fu.done() and app_fu.exception() is None:
400-
hashsum = task_record['hashsum']
423+
task_id = task['id']
424+
425+
if exception is None:
426+
hashsum = task['hashsum']
401427
if not hashsum:
402-
continue
403-
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
428+
pass # TODO: log an error? see below discussion
429+
else:
430+
t = {'hash': hashsum, 'exception': None, 'result': result}
431+
432+
# We are using pickle here since pickle dumps to a file in 'ab'
433+
# mode behave like an incremental log.
434+
pickle.dump(t, f)
435+
count += 1
404436

405-
# We are using pickle here since pickle dumps to a file in 'ab'
406-
# mode behave like a incremental log.
407-
pickle.dump(t, f)
408-
count += 1
409437
logger.debug("Task {} checkpointed".format(task_id))
438+
else:
439+
checkpoint_queue = self.checkpointable_tasks
440+
441+
for task_record in checkpoint_queue:
442+
task_id = task_record['id']
443+
444+
app_fu = task_record['app_fu']
445+
446+
assert app_fu.done(), "trying to checkpoint a task that is not done"
447+
448+
if app_fu.done() and app_fu.exception() is None:
449+
hashsum = task_record['hashsum']
450+
if not hashsum:
451+
continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
452+
t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
453+
454+
# We are using pickle here since pickle dumps to a file in 'ab'
455+
# mode behave like an incremental log.
456+
pickle.dump(t, f)
457+
count += 1
458+
459+
logger.debug("Task {} checkpointed".format(task_id))
410460

411461
self.checkpointed_tasks += count
412462

0 commit comments

Comments
 (0)