@@ -263,9 +263,29 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
def update_memo_result(self, task: TaskRecord, r: Any) -> None:
    """Update the memo table for a task that produced a result, then checkpoint.

    After calling ``self._update_memo(task)``, the action taken depends on
    ``self.checkpoint_mode``:

    * ``'task_exit'``: ``self.checkpoint`` is called immediately with the
      task and its result.
    * ``'manual'``, ``'periodic'``, ``'dfk_exit'``: the task is appended to
      ``self.checkpointable_tasks`` for a later checkpoint call.
    * ``None``: nothing further is done.

    Args:
        task: The task record whose result is being recorded.
        r: The result value of the task; may legitimately be ``None``.

    Raises:
        AssertionError: If ``self.checkpoint_mode`` is none of the values
            above. Raised explicitly so the check also runs under
            ``python -O``.
    """
    self._update_memo(task)

    if self.checkpoint_mode == 'task_exit':
        self.checkpoint(task=task, result=r)
    elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
        # with self._modify_checkpointable_tasks_lock:  # TODO: sort out use of this lock
        self.checkpointable_tasks.append(task)
    elif self.checkpoint_mode is None:
        pass
    else:
        # f-string so the offending mode actually appears in the message.
        raise AssertionError(
            f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
        )
275+
def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None:
    """Update the memo table for a task that raised, then checkpoint.

    After calling ``self._update_memo(task)``, the action taken depends on
    ``self.checkpoint_mode``:

    * ``'task_exit'``: ``self.checkpoint`` is called immediately with the
      task and its exception.
    * ``'manual'``, ``'periodic'``, ``'dfk_exit'``: the task is appended to
      ``self.checkpointable_tasks`` for a later checkpoint call.
    * ``None``: nothing further is done.

    Args:
        task: The task record whose failure is being recorded.
        e: The exception the task ended with.

    Raises:
        AssertionError: If ``self.checkpoint_mode`` is none of the values
            above. Raised explicitly so the check also runs under
            ``python -O``.
    """
    self._update_memo(task)

    if self.checkpoint_mode == 'task_exit':
        self.checkpoint(task=task, exception=e)
    elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
        # with self._modify_checkpointable_tasks_lock:  # TODO: sort out use of this lock
        self.checkpointable_tasks.append(task)
    elif self.checkpoint_mode is None:
        pass
    else:
        # f-string so the offending mode actually appears in the message.
        raise AssertionError(
            f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
        )
288+
269289 def _update_memo (self , task : TaskRecord ) -> None :
270290 """Updates the memoization lookup table with the result from a task.
271291 This doesn't move any values around but associates the memoization
@@ -357,7 +377,17 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
357377 else :
358378 return {}
359379
360- def checkpoint (self , * , task : Optional [TaskRecord ] = None ) -> None :
380+ # TODO: this call becomes even more multiplexed...
381+ # called with no parameters, we write out the task
382+ # called with a task record, we can now no longer expect to get the
383+ # result from the task record future, because it will not be
384+ # populated yet.
385+ # so then either we are given an exception, or if exception is None, then
386+ # checkpoint result. it's possible that result can be None as a
387+ # real result: in the case that exception is None.
388+ # what a horrible API that needs refactoring...
389+
390+ def checkpoint (self , * , task : Optional [TaskRecord ] = None , exception : Optional [BaseException ] = None , result : Any = None ) -> None :
361391 """Checkpoint the dfk incrementally to a checkpoint file.
362392
363393 When called with no argument, all tasks registered in self.checkpointable_tasks
@@ -377,11 +407,6 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
377407 """
378408 with self .checkpoint_lock :
379409
380- if task :
381- checkpoint_queue = [task ]
382- else :
383- checkpoint_queue = self .checkpointable_tasks
384-
385410 checkpoint_dir = '{0}/checkpoint' .format (self .run_dir )
386411 checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
387412
@@ -391,22 +416,47 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
391416 count = 0
392417
393418 with open (checkpoint_tasks , 'ab' ) as f :
394- for task_record in checkpoint_queue :
395- task_id = task_record ['id' ]
396419
397- app_fu = task_record ['app_fu' ]
420+ if task :
421+ # TODO: refactor with below
398422
399- if app_fu .done () and app_fu .exception () is None :
400- hashsum = task_record ['hashsum' ]
423+ task_id = task ['id' ]
424+
425+ if exception is None :
426+ hashsum = task ['hashsum' ]
401427 if not hashsum :
402- continue
403- t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
428+ pass # TODO: log an error? see below discussion
429+ else :
430+ t = {'hash' : hashsum , 'exception' : None , 'result' : result }
431+
432+ # We are using pickle here since pickle dumps to a file in 'ab'
433+ # mode behave like a incremental log.
434+ pickle .dump (t , f )
435+ count += 1
404436
405- # We are using pickle here since pickle dumps to a file in 'ab'
406- # mode behave like a incremental log.
407- pickle .dump (t , f )
408- count += 1
409437 logger .debug ("Task {} checkpointed" .format (task_id ))
438+ else :
439+ checkpoint_queue = self .checkpointable_tasks
440+
441+ for task_record in checkpoint_queue :
442+ task_id = task_record ['id' ]
443+
444+ app_fu = task_record ['app_fu' ]
445+
446+ assert app_fu .done (), "trying to checkpoint a task that is not done"
447+
448+ if app_fu .done () and app_fu .exception () is None :
449+ hashsum = task_record ['hashsum' ]
450+ if not hashsum :
451+ continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
452+ t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
453+
454+ # We are using pickle here since pickle dumps to a file in 'ab'
455+ # mode behave like a incremental log.
456+ pickle .dump (t , f )
457+ count += 1
458+
459+ logger .debug ("Task {} checkpointed" .format (task_id ))
410460
411461 self .checkpointed_tasks += count
412462
0 commit comments