
Commit c8f6b14 (dev)

1 parent 3824e27 commit c8f6b14

2 files changed: +78 additions, -73 deletions

parsl/dataflow/dflow.py

Lines changed: 5 additions & 73 deletions
@@ -7,7 +7,6 @@
 import logging
 import os
 import pathlib
-import pickle
 import random
 import sys
 import threading
@@ -99,8 +98,6 @@ def __init__(self, config: Config) -> None:

         logger.info("Parsl version: {}".format(get_version()))

-        self.checkpoint_lock = threading.Lock()
-
         self.usage_tracker = UsageTracker(self)
         self.usage_tracker.send_start_message()

@@ -177,7 +174,8 @@ def __init__(self, config: Config) -> None:
             checkpoint_files = []

         self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
-        self.checkpointed_tasks = 0
+        self.memoizer.run_dir = self.run_dir
+
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode

@@ -575,7 +573,7 @@ def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None:
         # Do we need to checkpoint now, or queue for later,
         # or do nothing?
         if self.checkpoint_mode == 'task_exit':
-            self.checkpoint(tasks=[task_record])
+            self.memoizer.checkpoint(tasks=[task_record])
         elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'):
             with self._modify_checkpointable_tasks_lock:
                 self.checkpointable_tasks.append(task_record)
@@ -1259,7 +1257,7 @@ def cleanup(self) -> None:

         # TODO: accesses to self.checkpointable_tasks should happen
         # under a lock?
-        self.checkpoint(self.checkpointable_tasks)
+        self.memoizer.checkpoint(self.checkpointable_tasks)

         if self._checkpoint_timer:
             logger.info("Stopping checkpoint timer")
@@ -1334,76 +1332,10 @@ def cleanup(self) -> None:

     def invoke_checkpoint(self):
         with self._modify_checkpointable_tasks_lock:
-            r = self.checkpoint(self.checkpointable_tasks)
+            r = self.memoizer.checkpoint(self.checkpointable_tasks)
             self.checkpointable_tasks = []
             return r

-    def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
-        """Checkpoint the dfk incrementally to a checkpoint file.
-
-        When called, every task that has been completed yet not
-        checkpointed is checkpointed to a file.
-
-        Kwargs:
-            - tasks (List of task records) : List of task ids to checkpoint. Default=None
-              if set to None, we iterate over all tasks held by the DFK.
-
-        .. note::
-            Checkpointing only works if memoization is enabled
-
-        Returns:
-            Checkpoint dir if checkpoints were written successfully.
-            By default the checkpoints are written to the RUNDIR of the current
-            run under RUNDIR/checkpoints/{tasks.pkl, dfk.pkl}
-        """
-        with self.checkpoint_lock:
-            checkpoint_queue = tasks
-
-            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
-            checkpoint_dfk = checkpoint_dir + '/dfk.pkl'
-            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
-
-            if not os.path.exists(checkpoint_dir):
-                os.makedirs(checkpoint_dir, exist_ok=True)
-
-            with open(checkpoint_dfk, 'wb') as f:
-                state = {'rundir': self.run_dir,
-                         'task_count': self.task_count
-                         }
-                pickle.dump(state, f)
-
-            count = 0
-
-            with open(checkpoint_tasks, 'ab') as f:
-                for task_record in checkpoint_queue:
-                    task_id = task_record['id']
-
-                    app_fu = task_record['app_fu']
-
-                    if app_fu.done() and app_fu.exception() is None:
-                        hashsum = task_record['hashsum']
-                        if not hashsum:
-                            continue
-                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
-
-                        # We are using pickle here since pickle dumps to a file in 'ab'
-                        # mode behave like a incremental log.
-                        pickle.dump(t, f)
-                        count += 1
-                        logger.debug("Task {} checkpointed".format(task_id))
-
-            self.checkpointed_tasks += count
-
-            if count == 0:
-                if self.checkpointed_tasks == 0:
-                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
-                else:
-                    logger.debug("No tasks checkpointed in this pass.")
-            else:
-                logger.info("Done checkpointing {} tasks".format(count))
-
-            return checkpoint_dir
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
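
With this change the DataFlowKernel still decides when to checkpoint ('task_exit', 'periodic', 'dfk_exit', 'manual') while the Memoizer now owns how checkpoints are written. A minimal usage sketch that exercises the task_exit path, assuming a standard Parsl installation with its default thread executor (the double app and its argument are illustrative only):

import parsl
from parsl.config import Config

# checkpoint_mode='task_exit' routes each completed task through
# handle_app_update() into memoizer.checkpoint(), per the diff above;
# app_cache=True because checkpointing requires memoization.
parsl.load(Config(app_cache=True, checkpoint_mode='task_exit'))

@parsl.python_app(cache=True)
def double(x):
    return 2 * x

print(double(21).result())  # 42; the (hash, result) pair is checkpointed
parsl.dfk().cleanup()       # queued checkpoint modes are flushed in cleanup()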

parsl/dataflow/memoization.py

Lines changed: 73 additions & 0 deletions
@@ -4,6 +4,7 @@
 import logging
 import os
 import pickle
+import threading
 from functools import lru_cache, singledispatch
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

@@ -150,6 +151,8 @@ class Memoizer:

     """

+    run_dir: str
+
     def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """Initialize the memoizer.
@@ -163,6 +166,10 @@ def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_file
         self.dfk = dfk
         self.memoize = memoize

+        self.checkpointed_tasks = 0
+
+        self.checkpoint_lock = threading.Lock()
+
         # TODO: we always load checkpoints even if we then discard them...
         # this is more obvious here, less obvious in previous Parsl...
         checkpoint = self.load_checkpoints(checkpoint_files)
@@ -350,3 +357,69 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str,
             return self._load_checkpoints(checkpointDirs)
         else:
             return {}
+
+    def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:
+        """Checkpoint the dfk incrementally to a checkpoint file.
+
+        When called, every task that has been completed yet not
+        checkpointed is checkpointed to a file.
+
+        Kwargs:
+            - tasks (List of task records) : List of task ids to checkpoint. Default=None
+              if set to None, we iterate over all tasks held by the DFK.
+
+        .. note::
+            Checkpointing only works if memoization is enabled
+
+        Returns:
+            Checkpoint dir if checkpoints were written successfully.
+            By default the checkpoints are written to the RUNDIR of the current
+            run under RUNDIR/checkpoints/{tasks.pkl, dfk.pkl}
+        """
+        with self.checkpoint_lock:
+            checkpoint_queue = tasks
+
+            checkpoint_dir = '{0}/checkpoint'.format(self.run_dir)
+            checkpoint_dfk = checkpoint_dir + '/dfk.pkl'
+            checkpoint_tasks = checkpoint_dir + '/tasks.pkl'
+
+            if not os.path.exists(checkpoint_dir):
+                os.makedirs(checkpoint_dir, exist_ok=True)
+
+            with open(checkpoint_dfk, 'wb') as f:
+                state = {'rundir': self.run_dir,
+                         # TODO: this isn't relevant to checkpointing? 'task_count': self.task_count
+                         }
+                pickle.dump(state, f)
+
+            count = 0
+
+            with open(checkpoint_tasks, 'ab') as f:
+                for task_record in checkpoint_queue:
+                    task_id = task_record['id']
+
+                    app_fu = task_record['app_fu']
+
+                    if app_fu.done() and app_fu.exception() is None:
+                        hashsum = task_record['hashsum']
+                        if not hashsum:
+                            continue
+                        t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
+
+                        # We are using pickle here since pickle dumps to a file in 'ab'
+                        # mode behave like an incremental log.
+                        pickle.dump(t, f)
+                        count += 1
+                        logger.debug("Task {} checkpointed".format(task_id))
+
+            self.checkpointed_tasks += count
+
+            if count == 0:
+                if self.checkpointed_tasks == 0:
+                    logger.warning("No tasks checkpointed so far in this run. Please ensure caching is enabled")
+                else:
+                    logger.debug("No tasks checkpointed in this pass.")
+            else:
+                logger.info("Done checkpointing {} tasks".format(count))
+
+            return checkpoint_dir
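
Because records are appended with pickle.dump() on an 'ab'-mode file handle, tasks.pkl behaves as an incremental log: a reader simply calls pickle.load() repeatedly until EOFError, keeping the newest record per hash, which is essentially how checkpoint loading works on a later run. A minimal read-back sketch (read_task_log is a hypothetical helper, not part of the Memoizer API; the path is illustrative):

import pickle
from typing import Any, Dict

def read_task_log(path: str) -> Dict[str, Dict[str, Any]]:
    """Replay an append-mode pickle log, keeping the newest record per hash."""
    memo: Dict[str, Dict[str, Any]] = {}
    with open(path, 'rb') as f:
        while True:
            try:
                record = pickle.load(f)  # one {'hash', 'exception', 'result'} dict
            except EOFError:
                break  # end of the incremental log
            memo[record['hash']] = record
    return memo

# e.g. read_task_log('runinfo/000/checkpoint/tasks.pkl')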
