 import inspect
 import logging
 import os
-import pickle
 import random
 import sys
 import threading
@@ -96,8 +95,6 @@ def __init__(self, config: Config) -> None:
 
         logger.info("Parsl version: {}".format(get_version()))
 
-        self.checkpoint_lock = threading.Lock()
-
         self.usage_tracker = UsageTracker(self)
         self.usage_tracker.send_start_message()
 
@@ -168,7 +165,8 @@ def __init__(self, config: Config) -> None:
         checkpoint_files = []
 
         self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
-        self.checkpointed_tasks = 0
+        self.memoizer.run_dir = self.run_dir
+
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode
 
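For context on the values flowing into this constructor: config.app_cache and the checkpoint_files list (presumably derived from config.checkpoint_files in the lines just above this hunk) both come from the user's Config. A minimal usage sketch follows; the executor choice is illustrative, while the checkpointing fields are standard Config options:

    # Usage sketch: how app_cache and checkpoint_files reach the DFK constructor.
    import parsl
    from parsl.config import Config
    from parsl.executors import ThreadPoolExecutor
    from parsl.utils import get_all_checkpoints

    config = Config(
        executors=[ThreadPoolExecutor(label="local_threads")],  # illustrative executor
        app_cache=True,                          # memoization must be on for checkpointing
        checkpoint_files=get_all_checkpoints(),  # resume from checkpoints of earlier runs
    )
    parsl.load(config)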
@@ -560,7 +558,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
         # Do we need to checkpoint now, or queue for later,
         # or do nothing?
         if self.checkpoint_mode == 'task_exit':
-            self.checkpoint(tasks=[task_record])
+            self.memoizer.checkpoint(tasks=[task_record])
         elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
             with self._modify_checkpointable_tasks_lock:
                 self.checkpointable_tasks.append(task_record)
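Of the modes tested here, only 'task_exit' writes a checkpoint immediately; 'manual', 'periodic' and 'dfk_exit' queue the task record under _modify_checkpointable_tasks_lock for a later pass. A compact sketch of how a user selects between them at construction time (Config imported as in the sketch above; the comments describe the usual meanings and are assumptions, not taken from this diff):

    # The four checkpointing modes named in this branch, as construction-time choices.
    Config(app_cache=True, checkpoint_mode="task_exit")    # write as each app future completes
    Config(app_cache=True, checkpoint_mode="periodic",     # timer-driven batches, one per period
           checkpoint_period="00:30:00")
    Config(app_cache=True, checkpoint_mode="dfk_exit")     # one write during DFK cleanup
    Config(app_cache=True, checkpoint_mode="manual")       # written when invoke_checkpoint() is called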
@@ -1198,7 +1196,7 @@ def cleanup(self) -> None:
 
         # TODO: accesses to self.checkpointable_tasks should happen
         # under a lock?
-        self.checkpoint(self.checkpointable_tasks)
+        self.memoizer.checkpoint(self.checkpointable_tasks)
 
         if self._checkpoint_timer:
             logger.info("Stopping checkpoint timer")
@@ -1254,66 +1252,9 @@ def cleanup(self) -> None:
 
     def invoke_checkpoint(self) -> None:
         with self._modify_checkpointable_tasks_lock:
-            self.checkpoint(self.checkpointable_tasks)
+            self.memoizer.checkpoint(self.checkpointable_tasks)
             self.checkpointable_tasks = []
 
-    def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
-        """Checkpoint the dfk incrementally to a checkpoint file.
-
-        When called, every task that has been completed yet not
-        checkpointed is checkpointed to a file.
-
-        Kwargs:
-             - tasks (List of task records) : List of task ids to checkpoint. Default=None
-               if set to None, we iterate over all tasks held by the DFK.
-
-        .. note::
-            Checkpointing only works if memoization is enabled
-
-        Returns:
-            Checkpoint dir if checkpoints were written successfully.
-            By default the checkpoints are written to the RUNDIR of the current
-            run under RUNDIR/checkpoints/tasks.pkl
-        """
-        with self.checkpoint_lock:
-            checkpoint_queue = tasks
-
-            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
-            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
-
-            if not os.path.exists(checkpoint_dir):
-                os.makedirs(checkpoint_dir, exist_ok=True)
-
-            count = 0
-
-            with open(checkpoint_tasks, 'ab') as f:
-                for task_record in checkpoint_queue:
-                    task_id = task_record['id']
-
-                    app_fu = task_record['app_fu']
-
-                    if app_fu.done() and app_fu.exception() is None:
-                        hashsum = task_record['hashsum']
-                        if not hashsum:
-                            continue
-                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
-
-                        # We are using pickle here since pickle dumps to a file in 'ab'
-                        # mode behave like a incremental log.
-                        pickle.dump(t, f)
-                        count += 1
-                        logger.debug("Task {} checkpointed".format(task_id))
-
-            self.checkpointed_tasks += count
-
-            if count == 0:
-                if self.checkpointed_tasks == 0:
-                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
-                else:
-                    logger.debug("No tasks checkpointed in this pass.")
-            else:
-                logger.info("Done checkpointing {} tasks".format(count))
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
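The method deleted above does not reappear in this file; its logic presumably moves into the Memoizer, which is changed elsewhere in this PR. Below is a rough, self-contained sketch of what that relocated logic could look like, assuming it keeps the same append-mode pickle format under run_dir/checkpoint/tasks.pkl and carries its own lock and counter in place of the deleted DFK attributes, together with the matching read-back loop for an 'ab'-mode pickle log:

    # Sketch only -- not the Memoizer implementation from this PR. The constructor
    # mirrors the call in the __init__ hunk; the lock and counter stand in for the
    # checkpoint_lock and checkpointed_tasks attributes deleted from the DFK.
    import os
    import pickle
    import threading
    from typing import List, Optional, Sequence

    class Memoizer:
        def __init__(self, dfk, memoize: bool = True,
                     checkpoint_files: Optional[List[str]] = None) -> None:
            self.dfk = dfk
            self.memoize = memoize
            self.checkpoint_files = checkpoint_files or []
            self.run_dir: Optional[str] = None   # assigned by the DFK after construction
            self.checkpoint_lock = threading.Lock()
            self.checkpointed_tasks = 0

        def checkpoint(self, tasks: Sequence) -> None:
            """Append one pickle record per completed, hashable task record."""
            with self.checkpoint_lock:
                checkpoint_dir = os.path.join(self.run_dir, 'checkpoint')
                os.makedirs(checkpoint_dir, exist_ok=True)
                checkpoint_path = os.path.join(checkpoint_dir, 'tasks.pkl')

                count = 0
                with open(checkpoint_path, 'ab') as f:
                    for task_record in tasks:
                        app_fu = task_record['app_fu']
                        hashsum = task_record['hashsum']
                        if app_fu.done() and app_fu.exception() is None and hashsum:
                            # Each dump() appends a self-contained frame, so the file
                            # behaves as an incremental log across calls and runs.
                            pickle.dump({'hash': hashsum,
                                         'exception': None,
                                         'result': app_fu.result()}, f)
                            count += 1
                self.checkpointed_tasks += count


    def read_checkpoint_log(path: str) -> dict:
        """Replay an append-mode pickle log into a hash -> record mapping."""
        records = {}
        with open(path, 'rb') as f:
            while True:
                try:
                    record = pickle.load(f)
                except EOFError:
                    break
                records[record['hash']] = record
        return records

Because each pass only appends whole frames, repeated checkpointing never rewrites earlier records, and the reader simply lets a later frame for the same hash supersede an earlier one.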