
Commit 7a4b826

Load checkpoints in memoizer code, rather than DFK (#3759)
This is the first of a set of PRs to move checkpoint handling code out of the DFK into eventually-pluggable memoizers.

# Changed Behaviour

None.

## Type of change

- Code maintenance/cleanup
1 parent 045a1a7 commit 7a4b826
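As a condensed sketch of the interface change (names taken from the diff below; the surrounding DFK setup is elided):

```python
# Before this commit: the DFK loaded checkpoint data itself and handed the
# Memoizer a ready-made dict of hash -> Future entries.
self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint=checkpoints)

# After this commit: the DFK only decides which checkpoint directories to use;
# the Memoizer loads them itself (via its own load_checkpoints) during __init__.
self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
```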

2 files changed: +81 −75

parsl/dataflow/dflow.py

Lines changed: 5 additions & 73 deletions
@@ -28,7 +28,7 @@
 from parsl.data_provider.data_manager import DataManager
 from parsl.data_provider.files import File
 from parsl.dataflow.dependency_resolvers import SHALLOW_DEPENDENCY_RESOLVER
-from parsl.dataflow.errors import BadCheckpoint, DependencyError, JoinError
+from parsl.dataflow.errors import DependencyError, JoinError
 from parsl.dataflow.futures import AppFuture
 from parsl.dataflow.memoization import Memoizer
 from parsl.dataflow.rundirs import make_rundir
@@ -161,13 +161,13 @@ def __init__(self, config: Config) -> None:
                                   workflow_info))

         if config.checkpoint_files is not None:
-            checkpoints = self.load_checkpoints(config.checkpoint_files)
+            checkpoint_files = config.checkpoint_files
         elif config.checkpoint_files is None and config.checkpoint_mode is not None:
-            checkpoints = self.load_checkpoints(get_all_checkpoints(self.run_dir))
+            checkpoint_files = get_all_checkpoints(self.run_dir)
         else:
-            checkpoints = {}
+            checkpoint_files = []

-        self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint=checkpoints)
+        self.memoizer = Memoizer(self, memoize=config.app_cache, checkpoint_files=checkpoint_files)
         self.checkpointed_tasks = 0
         self._checkpoint_timer = None
         self.checkpoint_mode = config.checkpoint_mode
@@ -1317,74 +1317,6 @@ def checkpoint(self, tasks: Optional[Sequence[TaskRecord]] = None) -> str:

         return checkpoint_dir

-    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
-        """Load a checkpoint file into a lookup table.
-
-        The data being loaded from the pickle file mostly contains input
-        attributes of the task: func, args, kwargs, env...
-        To simplify the check of whether the exact task has been completed
-        in the checkpoint, we hash these input params and use it as the key
-        for the memoized lookup table.
-
-        Args:
-            - checkpointDirs (list) : List of filepaths to checkpoints
-               Eg. ['runinfo/001', 'runinfo/002']
-
-        Returns:
-            - memoized_lookup_table (dict)
-        """
-        memo_lookup_table = {}
-
-        for checkpoint_dir in checkpointDirs:
-            logger.info("Loading checkpoints from {}".format(checkpoint_dir))
-            checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
-            try:
-                with open(checkpoint_file, 'rb') as f:
-                    while True:
-                        try:
-                            data = pickle.load(f)
-                            # Copy and hash only the input attributes
-                            memo_fu: Future = Future()
-                            assert data['exception'] is None
-                            memo_fu.set_result(data['result'])
-                            memo_lookup_table[data['hash']] = memo_fu
-
-                        except EOFError:
-                            # Done with the checkpoint file
-                            break
-            except FileNotFoundError:
-                reason = "Checkpoint file was not found: {}".format(
-                    checkpoint_file)
-                logger.error(reason)
-                raise BadCheckpoint(reason)
-            except Exception:
-                reason = "Failed to load checkpoint: {}".format(
-                    checkpoint_file)
-                logger.error(reason)
-                raise BadCheckpoint(reason)
-
-            logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
-                                                                                  len(memo_lookup_table.keys())))
-        return memo_lookup_table
-
-    @typeguard.typechecked
-    def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
-        """Load checkpoints from the checkpoint files into a dictionary.
-
-        The results are used to pre-populate the memoizer's lookup_table
-
-        Kwargs:
-            - checkpointDirs (list) : List of run folder to use as checkpoints
-               Eg. ['runinfo/001', 'runinfo/002']
-
-        Returns:
-            - dict containing, hashed -> future mappings
-        """
-        if checkpointDirs:
-            return self._load_checkpoints(checkpointDirs)
-        else:
-            return {}
-
     @staticmethod
     def _log_std_streams(task_record: TaskRecord) -> None:
         tid = task_record['id']
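For orientation, a hedged example of the user-facing configuration that exercises the branch in the __init__ hunk above. The option names (app_cache, checkpoint_files, checkpoint_mode) and get_all_checkpoints appear in the diff; the import paths, the 'task_exit' mode value, and the specific directory names are assumptions for illustration only:

```python
from parsl.config import Config
from parsl.utils import get_all_checkpoints  # collects checkpoint dirs from earlier runs

# Either name checkpoint directories explicitly...
config = Config(app_cache=True,
                checkpoint_files=['runinfo/001', 'runinfo/002'],  # illustrative paths
                checkpoint_mode='task_exit')

# ...or leave checkpoint_files unset with a checkpoint_mode set, and the DFK
# will call get_all_checkpoints(run_dir) itself, as shown in the hunk above.
```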

parsl/dataflow/memoization.py

Lines changed: 76 additions & 2 deletions
@@ -2,10 +2,14 @@

 import hashlib
 import logging
+import os
 import pickle
 from functools import lru_cache, singledispatch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

+import typeguard
+
+from parsl.dataflow.errors import BadCheckpoint
 from parsl.dataflow.taskrecord import TaskRecord

 if TYPE_CHECKING:
@@ -146,7 +150,7 @@ class Memoizer:

     """

-    def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[str, Future[Any]] = {}):
+    def __init__(self, dfk: DataFlowKernel, *, memoize: bool = True, checkpoint_files: Sequence[str]):
         """Initialize the memoizer.

         Args:
@@ -159,6 +163,8 @@ def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[s
         self.dfk = dfk
         self.memoize = memoize

+        checkpoint = self.load_checkpoints(checkpoint_files)
+
         if self.memoize:
             logger.info("App caching initialized")
             self.memo_lookup_table = checkpoint
@@ -274,3 +280,71 @@ def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
         else:
             logger.debug(f"Storing app cache entry {task['hashsum']} with result from task {task_id}")
             self.memo_lookup_table[task['hashsum']] = r
+
+    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
+        """Load a checkpoint file into a lookup table.
+
+        The data being loaded from the pickle file mostly contains input
+        attributes of the task: func, args, kwargs, env...
+        To simplify the check of whether the exact task has been completed
+        in the checkpoint, we hash these input params and use it as the key
+        for the memoized lookup table.
+
+        Args:
+            - checkpointDirs (list) : List of filepaths to checkpoints
+               Eg. ['runinfo/001', 'runinfo/002']
+
+        Returns:
+            - memoized_lookup_table (dict)
+        """
+        memo_lookup_table = {}
+
+        for checkpoint_dir in checkpointDirs:
+            logger.info("Loading checkpoints from {}".format(checkpoint_dir))
+            checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
+            try:
+                with open(checkpoint_file, 'rb') as f:
+                    while True:
+                        try:
+                            data = pickle.load(f)
+                            # Copy and hash only the input attributes
+                            memo_fu: Future = Future()
+                            assert data['exception'] is None
+                            memo_fu.set_result(data['result'])
+                            memo_lookup_table[data['hash']] = memo_fu
+
+                        except EOFError:
+                            # Done with the checkpoint file
+                            break
+            except FileNotFoundError:
+                reason = "Checkpoint file was not found: {}".format(
+                    checkpoint_file)
+                logger.error(reason)
+                raise BadCheckpoint(reason)
+            except Exception:
+                reason = "Failed to load checkpoint: {}".format(
+                    checkpoint_file)
+                logger.error(reason)
+                raise BadCheckpoint(reason)
+
+            logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
+                                                                                  len(memo_lookup_table.keys())))
+        return memo_lookup_table
+
+    @typeguard.typechecked
+    def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
+        """Load checkpoints from the checkpoint files into a dictionary.
+
+        The results are used to pre-populate the memoizer's lookup_table
+
+        Kwargs:
+            - checkpointDirs (list) : List of run folder to use as checkpoints
+               Eg. ['runinfo/001', 'runinfo/002']
+
+        Returns:
+            - dict containing, hashed -> future mappings
+        """
+        if checkpointDirs:
+            return self._load_checkpoints(checkpointDirs)
+        else:
+            return {}
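The loader moved in above implies a simple on-disk layout: each checkpoint directory holds a tasks.pkl containing a stream of pickled dicts with 'hash', 'result' and 'exception' keys, read back with repeated pickle.load calls until EOFError. A minimal sketch of that assumption (directory name and record values are made up; this is not how Parsl itself writes checkpoints):

```python
import os
import pickle

checkpoint_dir = 'runinfo/001'  # illustrative path, as in the docstring above
os.makedirs(checkpoint_dir, exist_ok=True)

# Append one record per completed task; _load_checkpoints() keeps unpickling
# the stream and stops at EOFError, so concatenated pickles are fine.
with open(os.path.join(checkpoint_dir, 'tasks.pkl'), 'ab') as f:
    pickle.dump({'hash': 'deadbeef', 'result': 42, 'exception': None}, f)
```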
