Commit 670f584

dev

1 parent 72bcd37

2 files changed (+71, -64 lines)

parsl/dataflow/dflow.py

Lines changed: 5 additions & 64 deletions
@@ -6,7 +6,6 @@
 import inspect
 import logging
 import os
-import pickle
 import random
 import sys
 import threading
@@ -101,8 +100,6 @@ def __init__(self, config: Config) -> None:
 
         logger.info("Parsl version: {}".format(get_version()))
 
-        self.checkpoint_lock = threading.Lock()
-
         self.usage_tracker = UsageTracker(self)
         self.usage_tracker.send_start_message()
 
@@ -176,7 +173,8 @@ def __init__(self, config: Config) -> None:
             checkpoint_files = []
 
         self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
-        self.checkpointed_tasks = 0
+        self.memoizer.run_dir = self.run_dir
+
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode
 
@@ -569,7 +567,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
         # Do we need to checkpoint now, or queue for later,
         # or do nothing?
         if self.checkpoint_mode == 'task_exit':
-            self.checkpoint(tasks=[task_record])
+            self.memoizer.checkpoint(tasks=[task_record])
         elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
             with self._modify_checkpointable_tasks_lock:
                 self.checkpointable_tasks.append(task_record)
@@ -1210,7 +1208,7 @@ def cleanup(self) -> None:
 
         # TODO: accesses to self.checkpointable_tasks should happen
         # under a lock?
-        self.checkpoint(self.checkpointable_tasks)
+        self.memoizer.checkpoint(self.checkpointable_tasks)
 
         if self._checkpoint_timer:
             logger.info("Stopping checkpoint timer")
@@ -1274,66 +1272,9 @@ def cleanup(self) -> None:
 
     def invoke_checkpoint(self) -> None:
         with self._modify_checkpointable_tasks_lock:
-            self.checkpoint(self.checkpointable_tasks)
+            self.memoizer.checkpoint(self.checkpointable_tasks)
             self.checkpointable_tasks = []
 
-    def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
-        """Checkpoint the dfk incrementally to a checkpoint file.
-
-        When called, every task that has been completed yet not
-        checkpointed is checkpointed to a file.
-
-        Kwargs:
-             - tasks (List of task records) : List of task ids to checkpoint. Default=None
-               if set to None, we iterate over all tasks held by the DFK.
-
-        .. note::
-            Checkpointing only works if memoization is enabled
-
-        Returns:
-            Checkpoint dir if checkpoints were written successfully.
-            By default the checkpoints are written to the RUNDIR of the current
-            run under RUNDIR/checkpoints/tasks.pkl
-        """
-        with self.checkpoint_lock:
-            checkpoint_queue = tasks
-
-            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
-            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
-
-            if not os.path.exists(checkpoint_dir):
-                os.makedirs(checkpoint_dir, exist_ok=True)
-
-            count = 0
-
-            with open(checkpoint_tasks, 'ab') as f:
-                for task_record in checkpoint_queue:
-                    task_id = task_record['id']
-
-                    app_fu = task_record['app_fu']
-
-                    if app_fu.done() and app_fu.exception() is None:
-                        hashsum = task_record['hashsum']
-                        if not hashsum:
-                            continue
-                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
-
-                        # We are using pickle here since pickle dumps to a file in 'ab'
-                        # mode behave like an incremental log.
-                        pickle.dump(t, f)
-                        count += 1
-                        logger.debug("Task {} checkpointed".format(task_id))
-
-            self.checkpointed_tasks += count
-
-            if count == 0:
-                if self.checkpointed_tasks == 0:
-                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
-                else:
-                    logger.debug("No tasks checkpointed in this pass.")
-            else:
-                logger.info("Done checkpointing {} tasks".format(count))
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
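
For context, the task_exit path rewired in this diff can be driven end to end with a short Parsl script. The following is a minimal sketch, assuming the documented Config options (app_cache, checkpoint_mode) and the local ThreadPoolExecutor; it is not part of this commit.

import parsl
from parsl import python_app
from parsl.config import Config
from parsl.executors import ThreadPoolExecutor

# app_cache=True enables memoization, which checkpointing requires;
# checkpoint_mode='task_exit' makes the DFK call memoizer.checkpoint()
# for each task as its AppFuture completes successfully.
config = Config(
    executors=[ThreadPoolExecutor(label='local', max_threads=2)],
    app_cache=True,
    checkpoint_mode='task_exit',
)

@python_app(cache=True)
def double(x):
    return 2 * x

parsl.load(config)
print(double(21).result())  # 42; its result is appended to run_dir/checkpoint/tasks.pkl
parsl.dfk().cleanup()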

parsl/dataflow/memoization.py

Lines changed: 66 additions & 0 deletions
@@ -4,6 +4,7 @@
 import logging
 import os
 import pickle
+import threading
 from functools import lru_cache, singledispatch
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
 
@@ -150,6 +151,8 @@ class Memoizer:
 
     """
 
+    run_dir: str
+
     def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """Initialize the memoizer.
 
@@ -163,6 +166,12 @@ def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_file
         self.dfk = dfk
         self.memoize = memoize
 
+        self.checkpointed_tasks = 0
+
+        self.checkpoint_lock = threading.Lock()
+
+        # TODO: we always load checkpoints even if we then discard them...
+        # this is more obvious here, less obvious in previous Parsl...
         checkpoint = self.load_checkpoints(checkpoint_files)
 
         if self.memoize:
@@ -348,3 +357,60 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
             return self._load_checkpoints(checkpointDirs)
         else:
             return {}
+
+    def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
+        """Checkpoint the dfk incrementally to a checkpoint file.
+
+        When called, every task that has been completed yet not
+        checkpointed is checkpointed to a file.
+
+        Kwargs:
+             - tasks (List of task records) : List of task ids to checkpoint. Default=None
+               if set to None, we iterate over all tasks held by the DFK.
+
+        .. note::
+            Checkpointing only works if memoization is enabled
+
+        Returns:
+            Checkpoint dir if checkpoints were written successfully.
+            By default the checkpoints are written to the RUNDIR of the current
+            run under RUNDIR/checkpoints/tasks.pkl
+        """
+        with self.checkpoint_lock:
+            checkpoint_queue = tasks
+
+            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
+            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
+
+            if not os.path.exists(checkpoint_dir):
+                os.makedirs(checkpoint_dir, exist_ok=True)
+
+            count = 0
+
+            with open(checkpoint_tasks, 'ab') as f:
+                for task_record in checkpoint_queue:
+                    task_id = task_record['id']
+
+                    app_fu = task_record['app_fu']
+
+                    if app_fu.done() and app_fu.exception() is None:
+                        hashsum = task_record['hashsum']
+                        if not hashsum:
+                            continue
+                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
+
+                        # We are using pickle here since pickle dumps to a file in 'ab'
+                        # mode behave like an incremental log.
+                        pickle.dump(t, f)
+                        count += 1
+                        logger.debug("Task {} checkpointed".format(task_id))
+
+            self.checkpointed_tasks += count
+
+            if count == 0:
+                if self.checkpointed_tasks == 0:
+                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
+                else:
+                    logger.debug("No tasks checkpointed in this pass.")
+            else:
+                logger.info("Done checkpointing {} tasks".format(count))
