@@ -179,6 +179,8 @@ def __init__(self, config: Config) -> None:
179179 self .checkpointed_tasks = 0
180180 self ._checkpoint_timer = None
181181 self .checkpoint_mode = config .checkpoint_mode
182+
183+ self ._modify_checkpointable_tasks_lock = threading .Lock ()
182184 self .checkpointable_tasks : List [TaskRecord ] = []
183185
184186 # this must be set before executors are added since add_executors calls
@@ -204,7 +206,7 @@ def __init__(self, config: Config) -> None:
204206 except Exception :
205207 raise ConfigurationError ("invalid checkpoint_period provided: {0} expected HH:MM:SS" .format (config .checkpoint_period ))
206208 checkpoint_period = (h * 3600 ) + (m * 60 ) + s
207- self ._checkpoint_timer = Timer (self .checkpoint , interval = checkpoint_period , name = "Checkpoint" )
209+ self ._checkpoint_timer = Timer (self .invoke_checkpoint , interval = checkpoint_period , name = "Checkpoint" )
208210
209211 self .task_count = 0
210212 self .tasks : Dict [int , TaskRecord ] = {}
@@ -569,7 +571,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
569571 if self .checkpoint_mode == 'task_exit' :
570572 self .checkpoint (tasks = [task_record ])
571573 elif self .checkpoint_mode in ('manual' , 'periodic' , 'dfk_exit' ):
572- with self .checkpoint_lock :
574+ with self ._modify_checkpointable_tasks_lock :
573575 self .checkpointable_tasks .append (task_record )
574576 elif self .checkpoint_mode is None :
575577 pass
@@ -1205,7 +1207,10 @@ def cleanup(self) -> None:
12051207 # Checkpointing takes priority over the rest of the tasks
12061208 # checkpoint if any valid checkpoint method is specified
12071209 if self .checkpoint_mode is not None :
1208- self .checkpoint ()
1210+
1211+ # TODO: accesses to self.checkpointable_tasks should happen
1212+ # under a lock?
1213+ self .checkpoint (self .checkpointable_tasks )
12091214
12101215 if self ._checkpoint_timer :
12111216 logger .info ("Stopping checkpoint timer" )
@@ -1267,7 +1272,12 @@ def cleanup(self) -> None:
12671272 # should still see it.
12681273 logger .info ("DFK cleanup complete" )
12691274
1270- def checkpoint (self , tasks : Optional [Sequence [TaskRecord ]] = None ) -> None :
def invoke_checkpoint(self) -> None:
    """Checkpoint every queued task, then empty the queue.

    This method is handed to the "Checkpoint" Timer in ``__init__`` as its
    callback.  The write and the reset both happen while holding
    ``_modify_checkpointable_tasks_lock`` - the same lock taken when a
    task record is appended to ``checkpointable_tasks`` - so an append
    cannot land between the two steps and be discarded by the reset.
    """
    with self._modify_checkpointable_tasks_lock:
        pending = self.checkpointable_tasks
        self.checkpoint(pending)
        # Rebind to a fresh list only after checkpoint() has returned; if
        # it raised, the queued task records are left in place.
        self.checkpointable_tasks = []
1279+
1280+ def checkpoint (self , tasks : Sequence [TaskRecord ]) -> None :
12711281 """Checkpoint the dfk incrementally to a checkpoint file.
12721282
12731283 When called, every task that has been completed yet not
@@ -1286,11 +1296,7 @@ def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> None:
12861296 run under RUNDIR/checkpoints/tasks.pkl
12871297 """
12881298 with self .checkpoint_lock :
1289- if tasks :
1290- checkpoint_queue = tasks
1291- else :
1292- checkpoint_queue = self .checkpointable_tasks
1293- self .checkpointable_tasks = []
1299+ checkpoint_queue = tasks
12941300
12951301 checkpoint_dir = '{0}/checkpoint' .format (self .run_dir )
12961302 checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
0 commit comments