|
1 | 1 | import os
|
2 | 2 | import shelve
|
3 | 3 | from contextlib import contextmanager
|
| 4 | +from functools import wraps |
4 | 5 | from pathlib import Path
|
5 | 6 | from typing import Any, Callable, Dict, Optional, Union
|
6 | 7 |
|
| 8 | +from dill import Pickler, Unpickler |
| 9 | + |
7 | 10 | from gbmi.utils import backup as backup_file
|
8 | 11 | from gbmi.utils.hashing import get_hash_ascii
|
9 | 12 |
|
| 13 | +# monkeypatch shelve as per https://stackoverflow.com/q/52927236/377022 |
| 14 | +shelve.Pickler = Pickler |
| 15 | +shelve.Unpickler = Unpickler |
| 16 | + |
10 | 17 | memoshelve_cache: Dict[str, Dict[str, Any]] = {}
|
11 | 18 |
|
12 | 19 |
|
@@ -95,3 +102,36 @@ def uncache(
|
95 | 102 | key = get_hash((args, kwargs))
|
96 | 103 | if key in db:
|
97 | 104 | del db[key]
|
| 105 | + |
| 106 | + |
| 107 | +# for decorators |
| 108 | +def cache( |
| 109 | + filename: Path | str | None = None, |
| 110 | + cache: Dict[str, Dict[str, Any]] = memoshelve_cache, |
| 111 | + get_hash: Callable = get_hash_ascii, |
| 112 | + get_hash_mem: Optional[Callable] = None, |
| 113 | + print_cache_miss: bool = False, |
| 114 | + disable: bool = False, |
| 115 | +): |
| 116 | + def wrap(value: Callable): |
| 117 | + path = Path(filename or f".cache/{value.__name__}.shelve") |
| 118 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 119 | + |
| 120 | + @wraps(value) |
| 121 | + def wrapper(*args, **kwargs): |
| 122 | + if disable: |
| 123 | + return value(*args, **kwargs) |
| 124 | + else: |
| 125 | + with memoshelve( |
| 126 | + value, |
| 127 | + filename=path, |
| 128 | + cache=cache, |
| 129 | + get_hash=get_hash, |
| 130 | + get_hash_mem=get_hash_mem, |
| 131 | + print_cache_miss=print_cache_miss, |
| 132 | + )() as f: |
| 133 | + return f(*args, **kwargs) |
| 134 | + |
| 135 | + return wrapper |
| 136 | + |
| 137 | + return wrap |
0 commit comments