1414
1515from parsl .dataflow .errors import BadCheckpoint
1616from parsl .dataflow .taskrecord import TaskRecord
17- from parsl .errors import ConfigurationError , InternalConsistencyError
17+ from parsl .errors import ConfigurationError
1818from parsl .utils import Timer , get_all_checkpoints
1919
2020logger = logging .getLogger (__name__ )
@@ -289,9 +289,29 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
def update_memo_result(self, task: TaskRecord, r: Any) -> None:
    """Record a task's successful result in the memo table and feed it to
    the checkpointing machinery according to the configured mode.

    Parameters
    ----------
    task : TaskRecord
        The completed task to memoize.
    r : Any
        The result value to checkpoint. Passed explicitly because the
        task's app future may not be populated yet at this point.
    """
    self._update_memo(task)

    if self.checkpoint_mode == 'task_exit':
        # Checkpoint this task immediately, carrying the result explicitly.
        self.checkpoint(task=task, result=r)
    elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
        # with self._modify_checkpointable_tasks_lock:  # TODO: sort out use of this lock
        self.checkpointable_tasks.append(task)
    elif self.checkpoint_mode is None:
        pass
    else:
        # Bug fix: the original message lacked the f-prefix, so the mode was
        # never interpolated; and `assert` is stripped under `python -O`, so
        # raise to keep the invariant enforced.
        raise RuntimeError(f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization")
def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None:
    """Record a failed task in the memo table and feed its exception to
    the checkpointing machinery according to the configured mode.

    Parameters
    ----------
    task : TaskRecord
        The failed task to memoize.
    e : BaseException
        The exception to checkpoint. Passed explicitly because the task's
        app future may not be populated yet at this point.
    """
    self._update_memo(task)

    if self.checkpoint_mode == 'task_exit':
        # Checkpoint this task immediately, carrying the exception explicitly.
        self.checkpoint(task=task, exception=e)
    elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
        # with self._modify_checkpointable_tasks_lock:  # TODO: sort out use of this lock
        self.checkpointable_tasks.append(task)
    elif self.checkpoint_mode is None:
        pass
    else:
        # Bug fix: the original message lacked the f-prefix, so the mode was
        # never interpolated; and `assert` is stripped under `python -O`, so
        # raise to keep the invariant enforced.
        raise RuntimeError(f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization")
295315 def _update_memo (self , task : TaskRecord ) -> None :
296316 """Updates the memoization lookup table with the result from a task.
297317 This doesn't move any values around but associates the memoization
@@ -383,18 +403,17 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
383403 else :
384404 return {}
385405
def update_checkpoint(self, task_record: "TaskRecord") -> None:
    """Route a finished task into checkpointing based on the configured mode.

    In 'task_exit' mode the task is checkpointed immediately; in the
    deferred modes it is queued for a later bulk checkpoint; when
    checkpointing is disabled, nothing happens.
    """
    mode = self.checkpoint_mode
    if mode is None:
        # Checkpointing disabled: nothing to record.
        return
    if mode == 'task_exit':
        # Write this single task out right away.
        self.checkpoint(task=task_record)
        return
    if mode in ('manual', 'periodic', 'dfk_exit'):
        # Defer: queue the task under the lock for a later checkpoint() call.
        with self.checkpoint_lock:
            self.checkpointable_tasks.append(task_record)
        return
    raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}")
397- def checkpoint (self , * , task : Optional [TaskRecord ] = None ) -> None :
406+ # TODO: this call has become even more multiplexed.
407+ # Called with no parameters, it writes out every task queued in
408+ # self.checkpointable_tasks.
409+ # Called with a task record, we can no longer expect to read the
410+ # result from the task record's future, because it will not be
411+ # populated yet; so we checkpoint either the given exception, or,
412+ # when the exception is None, the given result. Note that result may
413+ # legitimately be None as a real result when exception is None.
414+ # What a horrible API that needs refactoring...
415+
416+ def checkpoint (self , * , task : Optional [TaskRecord ] = None , exception : Optional [BaseException ] = None , result : Any = None ) -> None :
398417 """Checkpoint the dfk incrementally to a checkpoint file.
399418
400419 When called with no argument, all tasks registered in self.checkpointable_tasks
@@ -414,11 +433,6 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
414433 """
415434 with self .checkpoint_lock :
416435
417- if task :
418- checkpoint_queue = [task ]
419- else :
420- checkpoint_queue = self .checkpointable_tasks
421-
422436 checkpoint_dir = '{0}/checkpoint' .format (self .run_dir )
423437 checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
424438
@@ -428,22 +442,47 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
428442 count = 0
429443
430444 with open (checkpoint_tasks , 'ab' ) as f :
431- for task_record in checkpoint_queue :
432- task_id = task_record ['id' ]
433445
434- app_fu = task_record ['app_fu' ]
446+ if task :
447+ # TODO: refactor with below
448+
449+ task_id = task ['id' ]
435450
436- if app_fu . done () and app_fu . exception () is None :
437- hashsum = task_record ['hashsum' ]
451+ if exception is None :
452+ hashsum = task ['hashsum' ]
438453 if not hashsum :
439- continue
440- t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
454+ pass # TODO: log an error? see below discussion
455+ else :
456+ t = {'hash' : hashsum , 'exception' : None , 'result' : result }
457+
458+ # We are using pickle here since pickle dumps to a file in 'ab'
459+ # mode behave like a incremental log.
460+ pickle .dump (t , f )
461+ count += 1
441462
442- # We are using pickle here since pickle dumps to a file in 'ab'
443- # mode behave like a incremental log.
444- pickle .dump (t , f )
445- count += 1
446463 logger .debug ("Task {} checkpointed" .format (task_id ))
464+ else :
465+ checkpoint_queue = self .checkpointable_tasks
466+
467+ for task_record in checkpoint_queue :
468+ task_id = task_record ['id' ]
469+
470+ app_fu = task_record ['app_fu' ]
471+
472+ assert app_fu .done (), "trying to checkpoint a task that is not done"
473+
474+ if app_fu .done () and app_fu .exception () is None :
475+ hashsum = task_record ['hashsum' ]
476+ if not hashsum :
477+ continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
478+ t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
479+
480+ # We are using pickle here since pickle dumps to a file in 'ab'
481+ # mode behave like a incremental log.
482+ pickle .dump (t , f )
483+ count += 1
484+
485+ logger .debug ("Task {} checkpointed" .format (task_id ))
447486
448487 self .checkpointed_tasks += count
449488
0 commit comments