
Commit 3824e27

load checkpoints in memoizer
1 parent 6737efc commit 3824e27

File tree

2 files changed: +83, -75 lines changed

parsl/dataflow/dflow.py

Lines changed: 5 additions & 73 deletions
@@ -30,7 +30,7 @@
 from parsl.data_provider.data_manager import DataManager
 from parsl.data_provider.files import File
 from parsl.dataflow.dependency_resolvers import SHALLOW_DEPENDENCY_RESOLVER
-from parsl.dataflow.errors import BadCheckpoint, DependencyError, JoinError
+from parsl.dataflow.errors import DependencyError, JoinError
 from parsl.dataflow.futures import AppFuture
 from parsl.dataflow.memoization import Memoizer
 from parsl.dataflow.rundirs import make_rundir
@@ -170,13 +170,13 @@ def __init__(self, config: Config) -> None:
                                            workflow_info)

         if config.checkpoint_files is not None:
-            checkpoints = self.load_checkpoints(config.checkpoint_files)
+            checkpoint_files = config.checkpoint_files
         elif config.checkpoint_files is None and config.checkpoint_mode is not None:
-            checkpoints = self.load_checkpoints(get_all_checkpoints(self.run_dir))
+            checkpoint_files = get_all_checkpoints(self.run_dir)
         else:
-            checkpoints = {}
+            checkpoint_files = []

-        self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint=checkpoints)
+        self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
         self.checkpointed_tasks = 0
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode
@@ -1404,74 +1404,6 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:

         return checkpoint_dir

-    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
-        """Load a checkpoint file into a lookup table.
-
-        The data being loaded from the pickle file mostly contains input
-        attributes of the task: func, args, kwargs, env...
-        To simplify the check of whether the exact task has been completed
-        in the checkpoint, we hash these input params and use it as the key
-        for the memoized lookup table.
-
-        Args:
-            - checkpointDirs (list) : List of filepaths to checkpoints
-               Eg. ['runinfo/001', 'runinfo/002']
-
-        Returns:
-            - memoized_lookup_table (dict)
-        """
-        memo_lookup_table = {}
-
-        for checkpoint_dir in checkpointDirs:
-            logger.info("Loading checkpoints from {}".format(checkpoint_dir))
-            checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
-            try:
-                with open(checkpoint_file, 'rb') as f:
-                    while True:
-                        try:
-                            data = pickle.load(f)
-                            # Copy and hash only the input attributes
-                            memo_fu: Future = Future()
-                            assert data['exception'] is None
-                            memo_fu.set_result(data['result'])
-                            memo_lookup_table[data['hash']] = memo_fu
-
-                        except EOFError:
-                            # Done with the checkpoint file
-                            break
-            except FileNotFoundError:
-                reason = "Checkpoint file was not found: {}".format(
-                    checkpoint_file)
-                logger.error(reason)
-                raise BadCheckpoint(reason)
-            except Exception:
-                reason = "Failed to load checkpoint: {}".format(
-                    checkpoint_file)
-                logger.error(reason)
-                raise BadCheckpoint(reason)
-
-            logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
-                                                                                  len(memo_lookup_table.keys())))
-        return memo_lookup_table
-
-    @typeguard.typechecked
-    def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
-        """Load checkpoints from the checkpoint files into a dictionary.
-
-        The results are used to pre-populate the memoizer's lookup_table
-
-        Kwargs:
-            - checkpointDirs (list) : List of run folder to use as checkpoints
-               Eg. ['runinfo/001', 'runinfo/002']
-
-        Returns:
-            - dict containing, hashed -> future mappings
-        """
-        if checkpointDirs:
-            return self._load_checkpoints(checkpointDirs)
-        else:
-            return {}
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
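
For context, the user-facing side of checkpointing is untouched by this commit: the DataFlowKernel still reads checkpoint_mode and checkpoint_files from the Config, it just passes the list of checkpoint directories on to the Memoizer instead of unpickling them itself. A minimal sketch of that configuration, assuming the documented parsl.load, Config and parsl.utils.get_all_checkpoints helpers:

import parsl
from parsl.config import Config
from parsl.utils import get_all_checkpoints

# Checkpoint each app as it completes, and seed the cache from every
# previous run found under runinfo/ (these are the directories the DFK
# now hands straight to Memoizer as checkpoint_files).
config = Config(
    checkpoint_mode='task_exit',
    checkpoint_files=get_all_checkpoints(),
)
parsl.load(config)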

parsl/dataflow/memoization.py

Lines changed: 78 additions & 2 deletions
@@ -2,10 +2,14 @@

 import hashlib
 import logging
+import os
 import pickle
 from functools import lru_cache, singledispatch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

+import typeguard
+
+from parsl.dataflow.errors import BadCheckpoint
 from parsl.dataflow.taskrecord import TaskRecord

 if TYPE_CHECKING:
@@ -146,7 +150,7 @@ class Memoizer:

     """

-    def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[str, Future[Any]] = {}):
+    def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """Initialize the memoizer.

         Args:
@@ -159,6 +163,10 @@ def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[s
         self.dfk = dfk
         self.memoize = memoize

+        # TODO: we always load checkpoints even if we then discard them...
+        # this is more obvious here, less obvious in previous Parsl...
+        checkpoint = self.load_checkpoints(checkpoint_files)
+
         if self.memoize:
             logger.info("App caching initialized")
             self.memo_lookup_table = checkpoint
@@ -274,3 +282,71 @@ def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
         else:
             logger.debug(f"Storing app cache entry {task['hashsum']} with result from task {task_id}")
         self.memo_lookup_table[task['hashsum']] = r
+
+    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
+        """Load a checkpoint file into a lookup table.
+
+        The data being loaded from the pickle file mostly contains input
+        attributes of the task: func, args, kwargs, env...
+        To simplify the check of whether the exact task has been completed
+        in the checkpoint, we hash these input params and use it as the key
+        for the memoized lookup table.
+
+        Args:
+            - checkpointDirs (list) : List of filepaths to checkpoints
+               Eg. ['runinfo/001', 'runinfo/002']
+
+        Returns:
+            - memoized_lookup_table (dict)
+        """
+        memo_lookup_table = {}
+
+        for checkpoint_dir in checkpointDirs:
+            logger.info("Loading checkpoints from {}".format(checkpoint_dir))
+            checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
+            try:
+                with open(checkpoint_file, 'rb') as f:
+                    while True:
+                        try:
+                            data = pickle.load(f)
+                            # Copy and hash only the input attributes
+                            memo_fu: Future = Future()
+                            assert data['exception'] is None
+                            memo_fu.set_result(data['result'])
+                            memo_lookup_table[data['hash']] = memo_fu
+
+                        except EOFError:
+                            # Done with the checkpoint file
+                            break
+            except FileNotFoundError:
+                reason = "Checkpoint file was not found: {}".format(
+                    checkpoint_file)
+                logger.error(reason)
+                raise BadCheckpoint(reason)
+            except Exception:
+                reason = "Failed to load checkpoint: {}".format(
+                    checkpoint_file)
+                logger.error(reason)
+                raise BadCheckpoint(reason)
+
+            logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
+                                                                                  len(memo_lookup_table.keys())))
+        return memo_lookup_table
+
+    @typeguard.typechecked
+    def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
+        """Load checkpoints from the checkpoint files into a dictionary.
+
+        The results are used to pre-populate the memoizer's lookup_table
+
+        Kwargs:
+            - checkpointDirs (list) : List of run folder to use as checkpoints
+               Eg. ['runinfo/001', 'runinfo/002']
+
+        Returns:
+            - dict containing, hashed -> future mappings
+        """
+        if checkpointDirs:
+            return self._load_checkpoints(checkpointDirs)
+        else:
+            return {}
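
To make the relocated loader easier to follow, here is a self-contained sketch (not Parsl code; the sample hashes and results are invented) of the tasks.pkl layout that _load_checkpoints consumes: a run's checkpoint directory holds one file of pickled dicts appended back to back, each with 'hash', 'result' and 'exception' keys, which is why the reader calls pickle.load in a loop until EOFError and turns each record into a completed Future.

import os
import pickle
import tempfile
from concurrent.futures import Future

checkpoint_dir = tempfile.mkdtemp()
checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')

# A writer appends one pickled record per successfully completed task.
with open(checkpoint_file, 'ab') as f:
    pickle.dump({'hash': 'abc123', 'result': 42, 'exception': None}, f)
    pickle.dump({'hash': 'def456', 'result': 'done', 'exception': None}, f)

# The reader replays the records into hash -> Future mappings, stopping at EOF.
memo_lookup_table = {}
with open(checkpoint_file, 'rb') as f:
    while True:
        try:
            data = pickle.load(f)
        except EOFError:
            break
        memo_fu: Future = Future()
        memo_fu.set_result(data['result'])  # only successful results are checkpointed
        memo_lookup_table[data['hash']] = memo_fu

print(sorted(memo_lookup_table))  # ['abc123', 'def456']

Keeping this logic inside Memoizer also explains the new keyword-only constructor: callers now name memoize and checkpoint_files explicitly, and the memoizer owns both loading the checkpoint data and consulting it.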
