
Commit 303a2d8

dev
1 parent 98e9637 commit 303a2d8

File tree: 2 files changed, 73 additions and 64 deletions

parsl/dataflow/dflow.py

Lines changed: 5 additions & 64 deletions
@@ -6,7 +6,6 @@
 import inspect
 import logging
 import os
-import pickle
 import random
 import sys
 import threading
@@ -96,8 +95,6 @@ def __init__(self, config: Config) -> None:
 
         logger.info("Parsl version: {}".format(get_version()))
 
-        self.checkpoint_lock = threading.Lock()
-
         self.usage_tracker = UsageTracker(self)
         self.usage_tracker.send_start_message()
 
@@ -168,7 +165,8 @@ def __init__(self, config: Config) -> None:
             checkpoint_files = []
 
         self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
-        self.checkpointed_tasks = 0
+        self.memoizer.run_dir = self.run_dir
+
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode
 
@@ -560,7 +558,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
         # Do we need to checkpoint now, or queue for later,
         # or do nothing?
         if self.checkpoint_mode == 'task_exit':
-            self.checkpoint(tasks=[task_record])
+            self.memoizer.checkpoint(tasks=[task_record])
         elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
             with self._modify_checkpointable_tasks_lock:
                 self.checkpointable_tasks.append(task_record)
@@ -1198,7 +1196,7 @@ def cleanup(self) -> None:
 
         # TODO: accesses to self.checkpointable_tasks should happen
         # under a lock?
-        self.checkpoint(self.checkpointable_tasks)
+        self.memoizer.checkpoint(self.checkpointable_tasks)
 
         if self._checkpoint_timer:
             logger.info("Stopping checkpoint timer")
@@ -1254,66 +1252,9 @@ def cleanup(self) -> None:
 
     def invoke_checkpoint(self) -> None:
         with self._modify_checkpointable_tasks_lock:
-            self.checkpoint(self.checkpointable_tasks)
+            self.memoizer.checkpoint(self.checkpointable_tasks)
             self.checkpointable_tasks = []
 
-    def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
-        """Checkpoint the dfk incrementally to a checkpoint file.
-
-        When called, every task that has been completed yet not
-        checkpointed is checkpointed to a file.
-
-        Kwargs:
-            - tasks (List of task records) : List of task ids to checkpoint. Default=None
-              if set to None, we iterate over all tasks held by the DFK.
-
-        .. note::
-            Checkpointing only works if memoization is enabled
-
-        Returns:
-            Checkpoint dir if checkpoints were written successfully.
-            By default the checkpoints are written to the RUNDIR of the current
-            run under RUNDIR/checkpoints/tasks.pkl
-        """
-        with self.checkpoint_lock:
-            checkpoint_queue = tasks
-
-            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
-            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
-
-            if not os.path.exists(checkpoint_dir):
-                os.makedirs(checkpoint_dir, exist_ok=True)
-
-            count = 0
-
-            with open(checkpoint_tasks, 'ab') as f:
-                for task_record in checkpoint_queue:
-                    task_id = task_record['id']
-
-                    app_fu = task_record['app_fu']
-
-                    if app_fu.done() and app_fu.exception() is None:
-                        hashsum = task_record['hashsum']
-                        if not hashsum:
-                            continue
-                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
-
-                        # We are using pickle here since pickle dumps to a file in 'ab'
-                        # mode behave like a incremental log.
-                        pickle.dump(t, f)
-                        count += 1
-                        logger.debug("Task {} checkpointed".format(task_id))
-
-            self.checkpointed_tasks += count
-
-            if count == 0:
-                if self.checkpointed_tasks == 0:
-                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
-                else:
-                    logger.debug("No tasks checkpointed in this pass.")
-            else:
-                logger.info("Done checkpointing {} tasks".format(count))
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
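Taken together, the dflow.py side of this commit removes all checkpoint state from the DataFlowKernel: the lock, the checkpointed_tasks counter, and the checkpoint() method itself now live on the Memoizer, and the kernel only assigns run_dir and delegates. A minimal sketch of the resulting shape, using illustrative stand-in classes rather than Parsl's real API:

import threading
from typing import Any, Dict, Sequence


class Memoizer:
    # run_dir is assigned by the kernel after construction, as in the diff.
    run_dir: str

    def __init__(self) -> None:
        self.checkpointed_tasks = 0
        self.checkpoint_lock = threading.Lock()

    def checkpoint(self, tasks: Sequence[Dict[str, Any]]) -> str:
        # The lock serializes task_exit, periodic, and exit-time
        # checkpoint calls against the shared checkpoint file.
        with self.checkpoint_lock:
            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
            # ... write completed tasks, as in the memoization.py hunk below
        return checkpoint_dir


class Kernel:
    def __init__(self, run_dir: str) -> None:
        self.run_dir = run_dir
        self.memoizer = Memoizer()
        self.memoizer.run_dir = self.run_dir  # mirrors the __init__ change

    def invoke_checkpoint(self) -> None:
        # The kernel no longer implements checkpoint(); it delegates.
        self.memoizer.checkpoint([])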

parsl/dataflow/memoization.py

Lines changed: 68 additions & 0 deletions
@@ -4,6 +4,7 @@
 import logging
 import os
 import pickle
+import threading
 from functools import lru_cache, singledispatch
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
 
@@ -150,6 +151,8 @@ class Memoizer:
 
     """
 
+    run_dir: str
+
     def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """Initialize the memoizer.
 
@@ -163,6 +166,12 @@ def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_file
         self.dfk = dfk
         self.memoize = memoize
 
+        self.checkpointed_tasks = 0
+
+        self.checkpoint_lock = threading.Lock()
+
+        # TODO: we always load checkpoints even if we then discard them...
+        # this is more obvious here, less obvious in previous Parsl...
         checkpoint = self.load_checkpoints(checkpoint_files)
 
         if self.memoize:
@@ -348,3 +357,62 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
             return self._load_checkpoints(checkpointDirs)
         else:
             return {}
+
+    def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
+        """Checkpoint the dfk incrementally to a checkpoint file.
+
+        When called, every task that has been completed yet not
+        checkpointed is checkpointed to a file.
+
+        Kwargs:
+            - tasks (List of task records) : List of task ids to checkpoint. Default=None
+              if set to None, we iterate over all tasks held by the DFK.
+
+        .. note::
+            Checkpointing only works if memoization is enabled
+
+        Returns:
+            Checkpoint dir if checkpoints were written successfully.
+            By default the checkpoints are written to the RUNDIR of the current
+            run under RUNDIR/checkpoints/tasks.pkl
+        """
+        with self.checkpoint_lock:
+            checkpoint_queue = tasks
+
+            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
+            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
+
+            if not os.path.exists(checkpoint_dir):
+                os.makedirs(checkpoint_dir, exist_ok=True)
+
+            count = 0
+
+            with open(checkpoint_tasks, 'ab') as f:
+                for task_record in checkpoint_queue:
+                    task_id = task_record['id']
+
+                    app_fu = task_record['app_fu']
+
+                    if app_fu.done() and app_fu.exception() is None:
+                        hashsum = task_record['hashsum']
+                        if not hashsum:
+                            continue
+                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
+
+                        # We are using pickle here since pickle dumps to a file in 'ab'
+                        # mode behave like a incremental log.
+                        pickle.dump(t, f)
+                        count += 1
+                        logger.debug("Task {} checkpointed".format(task_id))
+
+            self.checkpointed_tasks += count
+
+            if count == 0:
+                if self.checkpointed_tasks == 0:
+                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
+                else:
+                    logger.debug("No tasks checkpointed in this pass.")
+            else:
+                logger.info("Done checkpointing {} tasks".format(count))
+
+        return checkpoint_dir
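The in-code comment about pickle and 'ab' mode is doing real work here: each pickle.dump appends a self-contained record, so the checkpoint file behaves as an incremental log that can be replayed with repeated pickle.load calls until EOFError. A standalone sketch of that round trip (just the underlying idiom, not Parsl's _load_checkpoints):

import os
import pickle
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'tasks.pkl')

# Append two records in separate open() calls, as task_exit
# checkpointing does; each dump is an independent pickle stream.
for record in ({'hash': 'a1', 'exception': None, 'result': 1},
               {'hash': 'b2', 'exception': None, 'result': 2}):
    with open(path, 'ab') as f:
        pickle.dump(record, f)

# Replay the log: load until EOFError; later records win per hash,
# so re-checkpointed tasks naturally supersede older entries.
memo = {}
with open(path, 'rb') as f:
    while True:
        try:
            record = pickle.load(f)
        except EOFError:
            break
        memo[record['hash']] = record['result']

print(memo)  # {'a1': 1, 'b2': 2}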
