22
33import hashlib
44import logging
5+ import os
56import pickle
67from functools import lru_cache , singledispatch
7- from typing import TYPE_CHECKING , Any , Dict , List , Optional
8+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Sequence
89
10+ import typeguard
11+
12+ from parsl .dataflow .errors import BadCheckpoint
913from parsl .dataflow .taskrecord import TaskRecord
1014
1115if TYPE_CHECKING :
@@ -146,7 +150,7 @@ class Memoizer:
146150
147151 """
148152
149- def __init__ (self , dfk : DataFlowKernel , memoize : bool = True , checkpoint : Dict [str , Future [ Any ]] = {} ):
153+ def __init__ (self , dfk : DataFlowKernel , * , memoize : bool = True , checkpoint_files : Sequence [str ] ):
150154 """Initialize the memoizer.
151155
152156 Args:
@@ -159,6 +163,10 @@ def __init__(self, dfk: DataFlowKernel, memoize: bool = True, checkpoint: Dict[s
159163 self .dfk = dfk
160164 self .memoize = memoize
161165
166+ # TODO: we always load checkpoints even if we then discard them...
167+ # this is more obvious here, less obvious in previous Parsl...
168+ checkpoint = self .load_checkpoints (checkpoint_files )
169+
162170 if self .memoize :
163171 logger .info ("App caching initialized" )
164172 self .memo_lookup_table = checkpoint
@@ -274,3 +282,71 @@ def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
274282 else :
275283 logger .debug (f"Storing app cache entry { task ['hashsum' ]} with result from task { task_id } " )
276284 self .memo_lookup_table [task ['hashsum' ]] = r
285+
286+ def _load_checkpoints (self , checkpointDirs : Sequence [str ]) -> Dict [str , Future [Any ]]:
287+ """Load a checkpoint file into a lookup table.
288+
289+ The data being loaded from the pickle file mostly contains input
290+ attributes of the task: func, args, kwargs, env...
291+ To simplify the check of whether the exact task has been completed
292+ in the checkpoint, we hash these input params and use it as the key
293+ for the memoized lookup table.
294+
295+ Args:
296+ - checkpointDirs (list) : List of filepaths to checkpoints
297+ Eg. ['runinfo/001', 'runinfo/002']
298+
299+ Returns:
300+ - memoized_lookup_table (dict)
301+ """
302+ memo_lookup_table = {}
303+
304+ for checkpoint_dir in checkpointDirs :
305+ logger .info ("Loading checkpoints from {}" .format (checkpoint_dir ))
306+ checkpoint_file = os .path .join (checkpoint_dir , 'tasks.pkl' )
307+ try :
308+ with open (checkpoint_file , 'rb' ) as f :
309+ while True :
310+ try :
311+ data = pickle .load (f )
312+ # Copy and hash only the input attributes
313+ memo_fu : Future = Future ()
314+ assert data ['exception' ] is None
315+ memo_fu .set_result (data ['result' ])
316+ memo_lookup_table [data ['hash' ]] = memo_fu
317+
318+ except EOFError :
319+ # Done with the checkpoint file
320+ break
321+ except FileNotFoundError :
322+ reason = "Checkpoint file was not found: {}" .format (
323+ checkpoint_file )
324+ logger .error (reason )
325+ raise BadCheckpoint (reason )
326+ except Exception :
327+ reason = "Failed to load checkpoint: {}" .format (
328+ checkpoint_file )
329+ logger .error (reason )
330+ raise BadCheckpoint (reason )
331+
332+ logger .info ("Completed loading checkpoint: {0} with {1} tasks" .format (checkpoint_file ,
333+ len (memo_lookup_table .keys ())))
334+ return memo_lookup_table
335+
336+ @typeguard .typechecked
337+ def load_checkpoints (self , checkpointDirs : Optional [Sequence [str ]]) -> Dict [str , Future ]:
338+ """Load checkpoints from the checkpoint files into a dictionary.
339+
340+ The results are used to pre-populate the memoizer's lookup_table
341+
342+ Kwargs:
343+ - checkpointDirs (list) : List of run folder to use as checkpoints
344+ Eg. ['runinfo/001', 'runinfo/002']
345+
346+ Returns:
347+ - dict containing, hashed -> future mappings
348+ """
349+ if checkpointDirs :
350+ return self ._load_checkpoints (checkpointDirs )
351+ else :
352+ return {}
0 commit comments