22
33import hashlib
44import logging
5+ import os
56import pickle
67from functools import lru_cache , singledispatch
7- from typing import TYPE_CHECKING , Any , Dict , List , Optional
8+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Sequence
89
10+ import typeguard
11+
12+ from parsl .dataflow .errors import BadCheckpoint
913from parsl .dataflow .taskrecord import TaskRecord
1014
1115if TYPE_CHECKING :
@@ -146,7 +150,7 @@ class Memoizer:
146150
147151 """
148152
149- def __init__ (self , dfk : DataFlowKernel , memoize : bool = True , checkpoint : Dict [str , Future [ Any ]] = {} ):
153+ def __init__ (self , dfk : DataFlowKernel , * , memoize : bool = True , checkpoint_files : Sequence [str ] ):
150154 """Initialize the memoizer.
151155
152156 Args:
@@ -159,6 +163,8 @@ def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[s
159163 self .dfk = dfk
160164 self .memoize = memoize
161165
166+ checkpoint = self .load_checkpoints (checkpoint_files )
167+
162168 if self .memoize :
163169 logger .info ("App caching initialized" )
164170 self .memo_lookup_table = checkpoint
@@ -274,3 +280,71 @@ def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
274280 else :
275281 logger .debug (f"Storing app cache entry { task ['hashsum' ]} with result from task { task_id } " )
276282 self .memo_lookup_table [task ['hashsum' ]] = r
283+
def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
    """Load checkpoint files into a memoization lookup table.

    Each directory in ``checkpointDirs`` is expected to contain a
    ``tasks.pkl`` file holding pickled records written back to back.
    Records are read one at a time until end-of-file. Every record is
    indexed with the keys ``'hash'``, ``'result'`` and ``'exception'``:
    the stored result is wrapped in an already-completed Future and
    filed in the table under the record's hash. If the same hash occurs
    more than once, the record read last replaces the earlier one.

    Args:
        - checkpointDirs (list) : List of filepaths to checkpoints
          Eg. ['runinfo/001', 'runinfo/002']

    Returns:
        - memoized_lookup_table (dict) : hash -> completed Future

    Raises:
        - BadCheckpoint : if a ``tasks.pkl`` file is missing, cannot be
          read or unpickled, or contains a record whose ``'exception'``
          entry is not None. The underlying error is attached as the
          cause.
    """
    memo_lookup_table: Dict[str, Future[Any]] = {}

    for checkpoint_dir in checkpointDirs:
        logger.info("Loading checkpoints from {}".format(checkpoint_dir))
        checkpoint_file = os.path.join(checkpoint_dir, 'tasks.pkl')
        try:
            # NOTE(security): unpickling can execute arbitrary code, so
            # checkpoint directories must come from a trusted source.
            with open(checkpoint_file, 'rb') as f:
                while True:
                    try:
                        data = pickle.load(f)
                    except EOFError:
                        # Done with the checkpoint file
                        break

                    # Only successful results may be replayed. This is an
                    # explicit check rather than an assert so that it is
                    # still enforced when Python runs with -O.
                    if data['exception'] is not None:
                        raise ValueError("Checkpoint record contains an exception")

                    memo_fu: Future = Future()
                    memo_fu.set_result(data['result'])
                    memo_lookup_table[data['hash']] = memo_fu
        except FileNotFoundError as e:
            reason = "Checkpoint file was not found: {}".format(
                checkpoint_file)
            logger.error(reason)
            raise BadCheckpoint(reason) from e
        except Exception as e:
            # Any other failure (corrupt pickle, malformed record, I/O
            # error) is reported uniformly as a bad checkpoint.
            reason = "Failed to load checkpoint: {}".format(
                checkpoint_file)
            logger.error(reason)
            raise BadCheckpoint(reason) from e

        logger.info("Completed loading checkpoint: {0} with {1} tasks".format(checkpoint_file,
                                                                              len(memo_lookup_table)))
    return memo_lookup_table
333+
@typeguard.typechecked
def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, Future]:
    """Build the initial memoization table from previous run directories.

    This is the guarded entry point around ``_load_checkpoints``: when no
    directories are supplied (``None`` or an empty sequence) nothing is
    read and an empty table is returned.

    Kwargs:
        - checkpointDirs (list) : List of run folder to use as checkpoints
          Eg. ['runinfo/001', 'runinfo/002']

    Returns:
        - dict mapping task hash -> future, empty when there is nothing
          to load
    """
    # Guard clause: nothing to load, so skip the file handling entirely.
    if not checkpointDirs:
        return {}

    return self._load_checkpoints(checkpointDirs)
0 commit comments