1414
1515from parsl .dataflow .errors import BadCheckpoint
1616from parsl .dataflow .taskrecord import TaskRecord
17- from parsl .errors import ConfigurationError , InternalConsistencyError
17+ from parsl .errors import ConfigurationError
1818from parsl .utils import Timer , get_all_checkpoints
1919
2020logger = logging .getLogger (__name__ )
@@ -286,9 +286,29 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
286286 def update_memo_result (self , task : TaskRecord , r : Any ) -> None :
287287 self ._update_memo (task )
288288
289+ if self .checkpoint_mode == 'task_exit' :
290+ self .checkpoint (task = task , result = r )
291+ elif self .checkpoint_mode in ('manual' , 'periodic' , 'dfk_exit' ):
292+ # with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock
293+ self .checkpointable_tasks .append (task )
294+ elif self .checkpoint_mode is None :
295+ pass
296+ else :
297+ assert False , f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
298+
289299 def update_memo_exception (self , task : TaskRecord , e : BaseException ) -> None :
290300 self ._update_memo (task )
291301
302+ if self .checkpoint_mode == 'task_exit' :
303+ self .checkpoint (task = task , exception = e )
304+ elif self .checkpoint_mode in ('manual' , 'periodic' , 'dfk_exit' ):
305+ # with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock
306+ self .checkpointable_tasks .append (task )
307+ elif self .checkpoint_mode is None :
308+ pass
309+ else :
310+ assert False , f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization"
311+
292312 def _update_memo (self , task : TaskRecord ) -> None :
293313 """Updates the memoization lookup table with the result from a task.
294314 This doesn't move any values around but associates the memoization
@@ -380,18 +400,17 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
380400 else :
381401 return {}
382402
383- def update_checkpoint (self , task_record : TaskRecord ) -> None :
384- if self .checkpoint_mode == 'task_exit' :
385- self .checkpoint (task = task_record )
386- elif self .checkpoint_mode in ('manual' , 'periodic' , 'dfk_exit' ):
387- with self .checkpoint_lock :
388- self .checkpointable_tasks .append (task_record )
389- elif self .checkpoint_mode is None :
390- pass
391- else :
392- raise InternalConsistencyError (f"Invalid checkpoint mode { self .checkpoint_mode } " )
393-
394- def checkpoint (self , * , task : Optional [TaskRecord ] = None ) -> None :
403+ # TODO: this call becomes even more multiplexed...
404+ # called with no parameters, we write out the task
405+ # called with a task record, we can now no longer expect to get the
406+ # result from the task record future, because it will not be
407+ # populated yet.
408+ # so then either we have an exception, or if exception is None, then
409+ # checkpoint result. it's possible that result can be None as a
410+ # real result: in the case that exception is None.
411+ # what a horrible API that needs refactoring...
412+
413+ def checkpoint (self , * , task : Optional [TaskRecord ] = None , exception : Optional [BaseException ] = None , result : Any = None ) -> None :
395414 """Checkpoint the dfk incrementally to a checkpoint file.
396415
397416 When called with no argument, all tasks registered in self.checkpointable_tasks
@@ -411,11 +430,6 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
411430 """
412431 with self .checkpoint_lock :
413432
414- if task :
415- checkpoint_queue = [task ]
416- else :
417- checkpoint_queue = self .checkpointable_tasks
418-
419433 checkpoint_dir = '{0}/checkpoint' .format (self .run_dir )
420434 checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
421435
@@ -425,22 +439,47 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None:
425439 count = 0
426440
427441 with open (checkpoint_tasks , 'ab' ) as f :
428- for task_record in checkpoint_queue :
429- task_id = task_record ['id' ]
430442
431- app_fu = task_record ['app_fu' ]
443+ if task :
444+ # TODO: refactor with below
445+
446+ task_id = task ['id' ]
432447
433- if app_fu . done () and app_fu . exception () is None :
434- hashsum = task_record ['hashsum' ]
448+ if exception is None :
449+ hashsum = task ['hashsum' ]
435450 if not hashsum :
436- continue
437- t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
451+ pass # TODO: log an error? see below discussion
452+ else :
453+ t = {'hash' : hashsum , 'exception' : None , 'result' : result }
454+
455+ # We are using pickle here since pickle dumps to a file in 'ab'
456+ # mode behaves like an incremental log.
457+ pickle .dump (t , f )
458+ count += 1
438459
439- # We are using pickle here since pickle dumps to a file in 'ab'
440- # mode behave like a incremental log.
441- pickle .dump (t , f )
442- count += 1
443460 logger .debug ("Task {} checkpointed" .format (task_id ))
461+ else :
462+ checkpoint_queue = self .checkpointable_tasks
463+
464+ for task_record in checkpoint_queue :
465+ task_id = task_record ['id' ]
466+
467+ app_fu = task_record ['app_fu' ]
468+
469+ assert app_fu .done (), "trying to checkpoint a task that is not done"
470+
471+ if app_fu .done () and app_fu .exception () is None :
472+ hashsum = task_record ['hashsum' ]
473+ if not hashsum :
474+ continue # TODO: log an error? maybe some tasks don't have hashsums legitimately?
475+ t = {'hash' : hashsum , 'exception' : None , 'result' : app_fu .result ()}
476+
477+ # We are using pickle here since pickle dumps to a file in 'ab'
478+ # mode behaves like an incremental log.
479+ pickle .dump (t , f )
480+ count += 1
481+
482+ logger .debug ("Task {} checkpointed" .format (task_id ))
444483
445484 self .checkpointed_tasks += count
446485
0 commit comments