@@ -383,8 +383,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A
383383 data = pickle .load (f )
384384 # Copy and hash only the input attributes
385385 memo_fu : Future = Future ()
386- assert data ['exception' ] is None
387- memo_fu .set_result (data ['result' ])
386+
387+ if data ['exception' ] is None :
388+ memo_fu .set_result (data ['result' ])
389+ else :
390+ assert data ['result' ] is None
391+ memo_fu .set_exception (data ['exception' ])
388392 memo_lookup_table [data ['hash' ]] = memo_fu
389393
390394 except EOFError :
@@ -467,20 +471,22 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None, exception: Optional[B
467471 # TODO: refactor with below
468472
469473 task_id = task ['id' ]
470-
471- if exception is None :
472- hashsum = task ['hashsum' ]
473- if not hashsum :
474- pass # TODO: log an error? see below discussion
475- else :
474+ hashsum = task ['hashsum' ]
475+ if not hashsum :
476+ pass # TODO: log an error? see below discussion
477+ else :
478+ if exception is None and self .filter_result_for_checkpoint (result ):
476479 t = {'hash' : hashsum , 'exception' : None , 'result' : result }
477-
478- # We are using pickle here since pickle dumps to a file in 'ab'
479- # mode behave like a incremental log.
480480 pickle .dump (t , f )
481481 count += 1
482-
483- logger .debug ("Task {} checkpointed" .format (task_id ))
482+ logger .debug ("Task {} checkpointed result" .format (task_id ))
483+ elif exception is not None and self .filter_exception_for_checkpoint (exception ):
484+ t = {'hash' : hashsum , 'exception' : exception , 'result' : None }
485+ pickle .dump (t , f )
486+ count += 1
487+ logger .debug ("Task {} checkpointed exception" .format (task_id ))
488+ else :
489+ pass # no checkpoint - maybe debug log? TODO
484490 else :
485491 checkpoint_queue = self .checkpointable_tasks
486492
@@ -491,18 +497,22 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None, exception: Optional[B
491497
492498 assert app_fu .done (), "trying to checkpoint a task that is not done"
493499
494- if app_fu .done () and app_fu .exception () is None :
495- hashsum = task_record ['hashsum' ]
496- if not hashsum :
497- continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
498- t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
500+ hashsum = task_record ['hashsum' ]
501+ if not hashsum :
502+ continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
499503
500- # We are using pickle here since pickle dumps to a file in 'ab'
501- # mode behave like a incremental log.
504+ if app_fu . exception () is None and self . filter_result_for_checkpoint ( app_fu . result ()):
505+ t = { 'hash' : hashsum , 'exception' : None , 'result' : app_fu . result ()}
502506 pickle .dump (t , f )
503507 count += 1
504-
505- logger .debug ("Task {} checkpointed" .format (task_id ))
508+ logger .debug ("Task {} checkpointed result" .format (task_id ))
509+ elif (e := app_fu .exception ()) is not None and self .filter_exception_for_checkpoint (e ):
510+ t = {'hash' : hashsum , 'exception' : app_fu .exception (), 'result' : None }
511+ pickle .dump (t , f )
512+ count += 1
513+ logger .debug ("Task {} checkpointed exception" .format (task_id ))
514+ else :
515+ pass # TODO: maybe log at debug level
506516
507517 self .checkpointed_tasks += count
508518
@@ -516,3 +526,11 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None, exception: Optional[B
516526
517527 if not task :
518528 self .checkpointable_tasks = []
529+
530+ def filter_result_for_checkpoint (self , result : Any ) -> bool :
531+ """Overridable method to decide if an task that ended with a successful result should be checkpointed"""
532+ return True
533+
534+ def filter_exception_for_checkpoint (self , exception : BaseException ) -> bool :
535+ """Overridable method to decide if an entry that ended with an exception should be checkpointed"""
536+ return False
0 commit comments