|
| 1 | +import hashlib |
| 2 | +import json |
| 3 | +import pickle |
| 4 | + |
| 5 | +from . import path as pathlib |
| 6 | + |
| 7 | + |
def diskcache(*args, **kwargs):
    """Decorate a function so that its results are cached via DiskCache.

    Supports two call forms:

    * Bare, ``@diskcache``: the single positional argument is the callable
      to wrap, and a ``DiskCache`` around it is returned directly.
    * Configured, ``@diskcache(...)``: the given positional and keyword
      arguments are forwarded to ``DiskCache`` after the function, and a
      decorator is returned.
    """
    used_bare = len(args) == 1 and not kwargs and callable(args[0])
    if used_bare:
        (fn,) = args
        return DiskCache(fn)

    def decorate(fn):
        # Forward the options captured by the outer call to the wrapper.
        return DiskCache(fn, *args, **kwargs)

    return decorate
| 13 | + |
| 14 | + |
# Defaults stored as attributes on the decorator function. DiskCache.__init__
# reads each one when the matching constructor argument is left as None.

# Base directory; every cache gets its own subfolder beneath it.
diskcache.root = '/tmp/elements/diskcache'
# When true, existing cache files are ignored and results are recomputed.
diskcache.refresh = False
# When true, a line is printed whenever a cache is loaded, filled, or cleared.
diskcache.verbose = False
| 18 | + |
| 19 | + |
class DiskCache:
    """Callable wrapper that memoizes a function's results as pickle files.

    The inputs of each call are JSON-encoded and hashed with SHA-256. The
    result is stored as ``<folder>/<hash>.pkl``, holding both the encoded
    inputs and the function output.
    """

    def __init__(self, fn, name=None, refresh=None, verbose=None, root=None):
        """Wrap ``fn`` with an on-disk cache.

        Args:
            fn: The function whose results are cached.
            name: Subfolder name for this cache. Defaults to
                ``<module stem>-<fn.__name__>``.
            refresh: If true, always recompute and overwrite cached results.
                Falls back to ``diskcache.refresh`` when None.
            verbose: If true, print a line on every load, fill, and clear.
                Falls back to ``diskcache.verbose`` when None.
            root: Base directory of the cache. Falls back to
                ``diskcache.root`` when None.
        """
        if name is None:
            # NOTE(review): __file__ is the file of the module defining this
            # class, not of the module defining fn, so same-named functions
            # from different modules share one default folder. Pass name=...
            # to keep them apart.
            name = f'{pathlib.Path(__file__).stem}-{fn.__name__}'
        self.fn = fn
        base = diskcache.root if root is None else root
        self.folder = pathlib.Path(base) / name
        self.refresh = diskcache.refresh if refresh is None else refresh
        self.verbose = diskcache.verbose if verbose is None else verbose

    def __call__(
        self,
        *args,
        _key=None,
        _refresh=None,
        **kwargs,
    ):
        """Return the cached result for these inputs, computing it if needed.

        Args:
            *args: Positional arguments forwarded to the wrapped function.
            _key: Optional JSON-serializable value used as the cache key
                instead of the call arguments.
            _refresh: Per-call override of the instance's refresh setting.
            **kwargs: Keyword arguments forwarded to the wrapped function.

        Returns:
            The output of the wrapped function, either freshly computed or
            loaded from the cache file.

        Raises:
            ValueError: If the cache key cannot be encoded as JSON.
            AssertionError: If the cache file for this hash was written for
                different inputs.
        """
        try:
            inputs = [args, kwargs] if _key is None else _key
            inputs = json.dumps(inputs, sort_keys=True)
        except (TypeError, ValueError) as e:
            # json.dumps raises TypeError for unsupported objects and
            # ValueError for circular references; report both uniformly.
            raise ValueError(
                'Diskcache requires function arguments to be JSON serializable. ' +
                'Alternatively, pass _key=... into the function with a cache key.'
            ) from e
        key = hashlib.sha256(inputs.encode('utf8')).hexdigest()
        filename = self.folder / f'{key}.pkl'
        refresh = self.refresh if _refresh is None else _refresh
        if filename.exists() and not refresh:
            if self.verbose:
                print(f'Loading diskcache: {filename}')
            # NOTE(review): unpickling executes arbitrary code stored in the
            # file. Only point the cache root at directories that untrusted
            # users cannot write to.
            data = pickle.loads(filename.read_bytes())
            if data['inputs'] != inputs:
                # Raised explicitly so the check also runs under `python -O`.
                raise AssertionError(('Hash collision', data, inputs))
            return data['output']
        if self.verbose:
            print(f'Filling diskcache: {filename}')
        output = self.fn(*args, **kwargs)
        data = {'inputs': inputs, 'output': output}
        filename.parent.mkdir()
        filename.write_bytes(pickle.dumps(data))
        return output

    def clear(self):
        """Delete every ``*.pkl`` cache file in this cache's folder."""
        if self.verbose:
            print(f'Clearing diskcache: {self.folder}')
        for file in self.folder.glob('*.pkl'):
            file.remove()
0 commit comments