
Commit 67e6a52

committed
load checkpoints in memoizer
1 parent 08865d7 commit 67e6a52

2 files changed: +83 -75 lines
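For context, the checkpoint behaviour this commit reshuffles is driven entirely by the run configuration. A minimal usage sketch, assuming a standard Parsl installation (the ThreadPoolExecutor and the 'task_exit' checkpoint mode are illustrative choices, not part of this commit):

import parsl
from parsl.config import Config
from parsl.executors.threads import ThreadPoolExecutor
from parsl.utils import get_all_checkpoints

config = Config(
    executors=[ThreadPoolExecutor()],
    app_cache=True,                          # enables memoization via the Memoizer
    checkpoint_mode='task_exit',             # checkpoint each task as it completes
    checkpoint_files=get_all_checkpoints(),  # e.g. ['runinfo/000', 'runinfo/001', ...]
)

parsl.load(config)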

parsl/dataflow/dflow.py

Lines changed: 5 additions & 73 deletions
@@ -30,7 +30,7 @@
 from parsl.data_provider.data_manager import DataManager
 from parsl.data_provider.files import File
 from parsl.dataflow.dependency_resolvers import SHALLOW_DEPENDENCY_RESOLVER
-from parsl.dataflow.errors import BadCheckpoint, DependencyError, JoinError
+from parsl.dataflow.errors import DependencyError, JoinError
 from parsl.dataflow.futures import AppFuture
 from parsl.dataflow.memoization import Memoizer
 from parsl.dataflow.rundirs import make_rundir
@@ -166,13 +166,13 @@ def __init__(self, config: Config) -> None:
                                  workflow_info)

         if config.checkpoint_files is not None:
-            checkpoints = self.load_checkpoints(config.checkpoint_files)
+            checkpoint_files = config.checkpoint_files
         elif config.checkpoint_files is None and config.checkpoint_mode is not None:
-            checkpoints = self.load_checkpoints(get_all_checkpoints(self.run_dir))
+            checkpoint_files = get_all_checkpoints(self.run_dir)
         else:
-            checkpoints = {}
+            checkpoint_files = []

-        self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint=checkpoints)
+        self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
         self.checkpointed_tasks = 0
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode
@@ -1400,74 +1400,6 @@ def checkpoint(self, tasks: Sequence[TaskRecord]) -> str:

         return checkpoint_dir

-    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
-        """Load a checkpoint file into a lookup table.
-
-        The data being loaded from the pickle file mostly contains input
-        attributes of the task: func, args, kwargs, env...
-        To simplify the check of whether the exact task has been completed
-        in the checkpoint, we hash these input params and use it as the key
-        for the memoized lookup table.
-
-        Args:
-            - checkpointDirs (list) : List of filepaths to checkpoints
-              Eg. ['runinfo/001', 'runinfo/002']
-
-        Returns:
-            - memoized_lookup_table (dict)
-        """
-        memo_lookup_table = {}
-
-        for checkpoint_dir in checkpointDirs:
-            logger.info("Loading checkpoints from {}".format(checkpoint_dir))
-            checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
-            try:
-                with open(checkpoint_file, 'rb') as f:
-                    while True:
-                        try:
-                            data = pickle.load(f)
-                            # Copy and hash only the input attributes
-                            memo_fu: Future = Future()
-                            assert data['exception'] is None
-                            memo_fu.set_result(data['result'])
-                            memo_lookup_table[data['hash']] = memo_fu
-
-                        except EOFError:
-                            # Done with the checkpoint file
-                            break
-            except FileNotFoundError:
-                reason = "Checkpoint file was not found: {}".format(
-                    checkpoint_file)
-                logger.error(reason)
-                raise BadCheckpoint(reason)
-            except Exception:
-                reason = "Failed to load checkpoint: {}".format(
-                    checkpoint_file)
-                logger.error(reason)
-                raise BadCheckpoint(reason)
-
-            logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
-                                                                                  len(memo_lookup_table.keys())))
-        return memo_lookup_table
-
-    @typeguard.typechecked
-    def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
-        """Load checkpoints from the checkpoint files into a dictionary.
-
-        The results are used to pre-populate the memoizer's lookup_table
-
-        Kwargs:
-            - checkpointDirs (list) : List of run folder to use as checkpoints
-              Eg. ['runinfo/001', 'runinfo/002']
-
-        Returns:
-            - dict containing, hashed -> future mappings
-        """
-        if checkpointDirs:
-            return self._load_checkpoints(checkpointDirs)
-        else:
-            return {}
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
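Net effect of the dflow.py changes: the DataFlowKernel no longer loads checkpoint data itself; it only decides which checkpoint directories apply and hands that list to the Memoizer. A hypothetical helper expressing the same selection logic as the hunk above (the name select_checkpoint_files is mine, not part of the commit):

from typing import Sequence

from parsl.config import Config
from parsl.utils import get_all_checkpoints


def select_checkpoint_files(config: Config, run_dir: str) -> Sequence[str]:
    # Mirror of the if/elif/else in DataFlowKernel.__init__ after this commit.
    if config.checkpoint_files is not None:
        return config.checkpoint_files
    elif config.checkpoint_mode is not None:
        return get_all_checkpoints(run_dir)
    else:
        return []

Note that the new Memoizer signature makes memoize and checkpoint_files keyword-only (the bare *), so callers such as the DataFlowKernel must spell out checkpoint_files=... at the call site.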

parsl/dataflow/memoization.py

Lines changed: 78 additions & 2 deletions
@@ -2,10 +2,14 @@

 import hashlib
 import logging
+import os
 import pickle
 from functools import lru_cache, singledispatch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

+import typeguard
+
+from parsl.dataflow.errors import BadCheckpoint
 from parsl.dataflow.taskrecord import TaskRecord

 if TYPE_CHECKING:
@@ -146,7 +150,7 @@ class Memoizer:

     """

-    def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[str, Future[Any]] = {}):
+    def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """Initialize the memoizer.

         Args:
@@ -159,6 +163,10 @@ def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[s
         self.dfk = dfk
         self.memoize = memoize

+        # TODO: we always load checkpoints even if we then discard them...
+        # this is more obvious here, less obvious in previous Parsl...
+        checkpoint = self.load_checkpoints(checkpoint_files)
+
         if self.memoize:
             logger.info("App caching initialized")
             self.memo_lookup_table = checkpoint
@@ -274,3 +282,71 @@ def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
         else:
             logger.debug(f"Storing app cache entry {task['hashsum']} with result from task {task_id}")
             self.memo_lookup_table[task['hashsum']] = r
+
+    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
+        """Load a checkpoint file into a lookup table.
+
+        The data being loaded from the pickle file mostly contains input
+        attributes of the task: func, args, kwargs, env...
+        To simplify the check of whether the exact task has been completed
+        in the checkpoint, we hash these input params and use it as the key
+        for the memoized lookup table.
+
+        Args:
+            - checkpointDirs (list) : List of filepaths to checkpoints
+              Eg. ['runinfo/001', 'runinfo/002']
+
+        Returns:
+            - memoized_lookup_table (dict)
+        """
+        memo_lookup_table = {}
+
+        for checkpoint_dir in checkpointDirs:
+            logger.info("Loading checkpoints from {}".format(checkpoint_dir))
+            checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
+            try:
+                with open(checkpoint_file, 'rb') as f:
+                    while True:
+                        try:
+                            data = pickle.load(f)
+                            # Copy and hash only the input attributes
+                            memo_fu: Future = Future()
+                            assert data['exception'] is None
+                            memo_fu.set_result(data['result'])
+                            memo_lookup_table[data['hash']] = memo_fu
+
+                        except EOFError:
+                            # Done with the checkpoint file
+                            break
+            except FileNotFoundError:
+                reason = "Checkpoint file was not found: {}".format(
+                    checkpoint_file)
+                logger.error(reason)
+                raise BadCheckpoint(reason)
+            except Exception:
+                reason = "Failed to load checkpoint: {}".format(
+                    checkpoint_file)
+                logger.error(reason)
+                raise BadCheckpoint(reason)
+
+            logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
+                                                                                  len(memo_lookup_table.keys())))
+        return memo_lookup_table
+
+    @typeguard.typechecked
+    def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
+        """Load checkpoints from the checkpoint files into a dictionary.
+
+        The results are used to pre-populate the memoizer's lookup_table
+
+        Kwargs:
+            - checkpointDirs (list) : List of run folder to use as checkpoints
+              Eg. ['runinfo/001', 'runinfo/002']
+
+        Returns:
+            - dict containing, hashed -> future mappings
+        """
+        if checkpointDirs:
+            return self._load_checkpoints(checkpointDirs)
+        else:
+            return {}
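The methods moved into the Memoizer define, in effect, the on-disk checkpoint contract: each checkpoint directory contains a tasks.pkl holding a stream of pickled records with 'hash', 'result' and 'exception' keys, read back until EOFError. A small standalone sketch of that read path (read_checkpoint_tasks is a hypothetical name; the error handling from _load_checkpoints is omitted):

import os
import pickle
from concurrent.futures import Future
from typing import Any, Dict


def read_checkpoint_tasks(checkpoint_dir: str) -> Dict[str, Future]:
    # Read <checkpoint_dir>/tasks.pkl the same way Memoizer._load_checkpoints does:
    # one pickled record per completed task, terminated by EOFError.
    lookup: Dict[str, Future] = {}
    checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
    with open(checkpoint_file, 'rb') as f:
        while True:
            try:
                data: Dict[str, Any] = pickle.load(f)
            except EOFError:
                break
            fut: Future = Future()
            fut.set_result(data['result'])  # only successful results are checkpointed
            lookup[data['hash']] = fut
    return lookup


# Example (assumes a checkpoint from an earlier run exists at this path):
# print(len(read_checkpoint_tasks('runinfo/001')))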
