Skip to content

Commit 3555f40

Browse files
committed
Add cache as a memoshelve decorator, use dill
1 parent 69b4e6f commit 3555f40

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

gbmi/utils/memoshelve.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import os
22
import shelve
33
from contextlib import contextmanager
4+
from functools import wraps
45
from pathlib import Path
56
from typing import Any, Callable, Dict, Optional, Union
67

8+
from dill import Pickler, Unpickler
9+
710
from gbmi.utils import backup as backup_file
811
from gbmi.utils.hashing import get_hash_ascii
912

13+
# monkeypatch shelve as per https://stackoverflow.com/q/52927236/377022
14+
shelve.Pickler = Pickler
15+
shelve.Unpickler = Unpickler
16+
1017
memoshelve_cache: Dict[str, Dict[str, Any]] = {}
1118

1219

@@ -95,3 +102,36 @@ def uncache(
95102
key = get_hash((args, kwargs))
96103
if key in db:
97104
del db[key]
105+
106+
107+
# for decorators
108+
def cache(
109+
filename: Path | str | None = None,
110+
cache: Dict[str, Dict[str, Any]] = memoshelve_cache,
111+
get_hash: Callable = get_hash_ascii,
112+
get_hash_mem: Optional[Callable] = None,
113+
print_cache_miss: bool = False,
114+
disable: bool = False,
115+
):
116+
def wrap(value: Callable):
117+
path = Path(filename or f".cache/{value.__name__}.shelve")
118+
path.parent.mkdir(parents=True, exist_ok=True)
119+
120+
@wraps(value)
121+
def wrapper(*args, **kwargs):
122+
if disable:
123+
return value(*args, **kwargs)
124+
else:
125+
with memoshelve(
126+
value,
127+
filename=path,
128+
cache=cache,
129+
get_hash=get_hash,
130+
get_hash_mem=get_hash_mem,
131+
print_cache_miss=print_cache_miss,
132+
)() as f:
133+
return f(*args, **kwargs)
134+
135+
return wrapper
136+
137+
return wrap

0 commit comments

Comments
 (0)