Skip to content

Commit 450d1b0

Browse files
committed
Add diskcache
1 parent 7f509cc commit 450d1b0

File tree

3 files changed

+126
-1
lines changed

3 files changed

+126
-1
lines changed

elements/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
__version__ = '3.19.1'
1+
__version__ = '3.20.0'
22

33
from .agg import Agg
44
from .checkpoint import Checkpoint, Saveable
55
from .config import Config
66
from .counter import Counter
7+
from .diskcache import diskcache
78
from .flags import Flags
89
from .fps import FPS
910
from .logger import Logger

elements/diskcache.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import hashlib
2+
import json
3+
import pickle
4+
5+
from . import path as pathlib
6+
7+
8+
def diskcache(*args, **kwargs):
9+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
10+
return DiskCache(args[0])
11+
else:
12+
return lambda fn: DiskCache(fn, *args, **kwargs)
13+
14+
15+
diskcache.root = '/tmp/elements/diskcache'
16+
diskcache.refresh = False
17+
diskcache.verbose = False
18+
19+
20+
class DiskCache:
21+
22+
def __init__(self, fn, name=None, refresh=None, verbose=None, root=None):
23+
if name is None:
24+
name = f'{pathlib.Path(__file__).stem}-{fn.__name__}'
25+
self.fn = fn
26+
self.folder = pathlib.Path(diskcache.root if root is None else root) / name
27+
self.refresh = diskcache.refresh if refresh is None else refresh
28+
self.verbose = diskcache.verbose if verbose is None else verbose
29+
30+
def __call__(
31+
self,
32+
*args,
33+
_key=None,
34+
_refresh=None,
35+
**kwargs,
36+
):
37+
try:
38+
inputs = [args, kwargs] if _key is None else _key
39+
inputs = json.dumps(inputs, sort_keys=True)
40+
except ValueError as e:
41+
raise ValueError(
42+
'Diskcache requires function arguments to be JSON serializable. ' +
43+
'Alternatively, pass _key=... into the function with a cache key.'
44+
) from e
45+
key = hashlib.sha256(inputs.encode('utf8')).hexdigest()
46+
filename = self.folder / f'{key}.pkl'
47+
refresh = self.refresh if _refresh is None else _refresh
48+
if filename.exists() and not refresh:
49+
self.verbose and print(f'Loading diskcache: {filename}')
50+
data = pickle.loads(filename.read_bytes())
51+
assert data['inputs'] == inputs, ('Hash collision', data, inputs)
52+
return data['output']
53+
else:
54+
self.verbose and print(f'Filling diskcache: {filename}')
55+
output = self.fn(*args, **kwargs)
56+
data = {'inputs': inputs, 'output': output}
57+
filename.parent.mkdir()
58+
filename.write_bytes(pickle.dumps(data))
59+
return output
60+
61+
def clear(self):
62+
self.verbose and print(f'Clearing diskcache: {self.folder}')
63+
for file in self.folder.glob('*.pkl'):
64+
file.remove()

tests/test_diskcache.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import elements
2+
import elements.diskcache
3+
4+
5+
class TestDiskcache:
6+
7+
def test_basic(self, tmpdir):
8+
elements.diskcache.root = tmpdir
9+
elements.diskcache.verbose = True
10+
11+
counter = elements.Counter()
12+
13+
@elements.diskcache
14+
def fn(foo, bar):
15+
counter.increment()
16+
return foo + bar
17+
18+
fn.clear()
19+
20+
assert counter == 0
21+
assert fn(1, 2) == 3
22+
assert counter == 1
23+
assert fn('a', 'b') == 'ab'
24+
assert counter == 2
25+
26+
assert fn(1, 2) == 3
27+
assert fn('a', 'b') == 'ab'
28+
assert counter == 2
29+
30+
fn.clear()
31+
assert fn(1, 2) == 3
32+
assert fn('a', 'b') == 'ab'
33+
assert counter == 4
34+
35+
assert fn(1, 2, _refresh=True) == 3
36+
assert counter == 5
37+
38+
assert fn(1, 2) == 3
39+
assert fn('a', 'b') == 'ab'
40+
assert counter == 5
41+
42+
43+
def test_names(self, tmpdir):
44+
45+
def make():
46+
@elements.diskcache('fn1')
47+
def fn(foo, bar):
48+
return foo + bar
49+
return fn
50+
fn1 = make()
51+
52+
def make():
53+
@elements.diskcache('fn2')
54+
def fn(foo, bar):
55+
return foo - bar
56+
return fn
57+
fn2 = make()
58+
59+
assert fn1(2, 1) == 3
60+
assert fn2(2, 1) == 1

0 commit comments

Comments
 (0)