Skip to content

Commit 238be4e

Browse files
committed
cache: docstrings and typehints
1 parent 9b636db commit 238be4e

File tree

2 files changed: +127 -24 lines changed

chebifier/_custom_cache.py

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,20 @@
88

99

1010
class PerSmilesPerModelLRUCache:
11+
"""
12+
A thread-safe, optionally persistent LRU cache for storing
13+
(SMILES, model_name) → result mappings.
14+
"""
15+
1116
def __init__(self, max_size: int = 100, persist_path: str | None = None):
12-
self._cache = OrderedDict()
17+
"""
18+
Initialize the cache.
19+
20+
Args:
21+
max_size (int): Maximum number of items to keep in the cache.
22+
persist_path (str | None): Optional path to persist cache using pickle.
23+
"""
24+
self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict()
1325
self._max_size = max_size
1426
self._lock = threading.Lock()
1527
self._persist_path = persist_path
@@ -21,6 +33,16 @@ def __init__(self, max_size: int = 100, persist_path: str | None = None):
2133
self._load_cache()
2234

2335
def get(self, smiles: str, model_name: str) -> Any | None:
36+
"""
37+
Retrieve value from cache if present, otherwise return None.
38+
39+
Args:
40+
smiles (str): SMILES string key.
41+
model_name (str): Model identifier.
42+
43+
Returns:
44+
Any | None: Cached value or None.
45+
"""
2446
key = (smiles, model_name)
2547
with self._lock:
2648
if key in self._cache:
@@ -32,6 +54,14 @@ def get(self, smiles: str, model_name: str) -> Any | None:
3254
return None
3355

3456
def set(self, smiles: str, model_name: str, value: Any) -> None:
57+
"""
58+
Store value in cache under (smiles, model_name) key.
59+
60+
Args:
61+
smiles (str): SMILES string key.
62+
model_name (str): Model identifier.
63+
value (Any): Value to cache.
64+
"""
3565
assert value is not None, "Value must not be None"
3666
key = (smiles, model_name)
3767
with self._lock:
@@ -42,6 +72,9 @@ def set(self, smiles: str, model_name: str, value: Any) -> None:
4272
self._cache.popitem(last=False)
4373

4474
def clear(self) -> None:
75+
"""
76+
Clear the cache and remove the persistence file if present.
77+
"""
4578
self._save_cache()
4679
with self._lock:
4780
self._cache.clear()
@@ -50,23 +83,38 @@ def clear(self) -> None:
5083
if self._persist_path and os.path.exists(self._persist_path):
5184
os.remove(self._persist_path)
5285

53-
def stats(self) -> dict:
86+
def stats(self) -> dict[str, int]:
87+
"""
88+
Return cache hit/miss statistics.
89+
90+
Returns:
91+
dict[str, int]: Dictionary with 'hits' and 'misses' keys.
92+
"""
5493
return {"hits": self.hits, "misses": self.misses}
5594

5695
def batch_decorator(self, func: Callable) -> Callable:
57-
"""Decorator for class methods that accept a batch of SMILES as a tuple,
58-
and want caching per (smiles, model_name) combination.
96+
"""
97+
Decorator for instance methods that are called with a batch of SMILES as a list (the decorated function itself receives only the uncached SMILES, as a tuple),
98+
and cache predictions per (smiles, model_name) key.
99+
100+
The instance is expected to have a `model_name` attribute.
101+
102+
Args:
103+
func (Callable): The method to decorate.
104+
105+
Returns:
106+
Callable: The wrapped method.
59107
"""
60108

61109
@wraps(func)
62-
def wrapper(instance, smiles_list: list[str]):
110+
def wrapper(instance, smiles_list: list[str]) -> list[Any]:
63111
assert isinstance(smiles_list, list), "smiles_list must be a list."
64112
model_name = getattr(instance, "model_name", None)
65113
assert model_name is not None, "Instance must have a model_name attribute."
66114

67-
results = []
68-
missing_smiles = []
69-
missing_indices = []
115+
results: list[tuple[int, Any]] = []
116+
missing_smiles: list[str] = []
117+
missing_indices: list[int] = []
70118

71119
# First: try to fetch all from cache
72120
for i, smiles in enumerate(smiles_list):
@@ -82,7 +130,8 @@ def wrapper(instance, smiles_list: list[str]):
82130
new_results = func(instance, tuple(missing_smiles))
83131
assert isinstance(
84132
new_results, Iterable
85-
), "Function must return an Iterable."
133+
), "Function must return an Iterable."
134+
86135
# Save to cache and append
87136
for smiles, prediction, missing_idx in zip(
88137
missing_smiles, new_results, missing_indices
@@ -101,21 +150,41 @@ def wrapper(instance, smiles_list: list[str]):
101150

102151
return wrapper
103152

104-
def __len__(self):
153+
def __len__(self) -> int:
154+
"""
155+
Return number of items in the cache.
156+
157+
Returns:
158+
int: Number of entries in the cache.
159+
"""
105160
with self._lock:
106161
return len(self._cache)
107162

108-
def __repr__(self):
163+
def __repr__(self) -> str:
164+
"""
165+
String representation of the underlying cache.
166+
167+
Returns:
168+
str: String version of the OrderedDict.
169+
"""
109170
return self._cache.__repr__()
110171

111-
def save(self):
172+
def save(self) -> None:
173+
"""
174+
Save the cache to disk, if persistence is enabled.
175+
"""
112176
self._save_cache()
113177

114-
def load(self):
178+
def load(self) -> None:
179+
"""
180+
Load the cache from disk, if persistence is enabled.
181+
"""
115182
self._load_cache()
116183

117184
def _save_cache(self) -> None:
118-
"""Serialize the cache to disk."""
185+
"""
186+
Serialize the cache to disk using pickle.
187+
"""
119188
if self._persist_path:
120189
try:
121190
with open(self._persist_path, "wb") as f:
@@ -124,7 +193,9 @@ def _save_cache(self) -> None:
124193
print(f"[Cache Save Error] {e}")
125194

126195
def _load_cache(self) -> None:
127-
"""Load the cache from disk."""
196+
"""
197+
Load the cache from disk, if the file exists and is non-empty.
198+
"""
128199
if (
129200
self._persist_path
130201
and os.path.exists(self._persist_path)

tests/test_cache.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,46 @@
88

99

1010
class DummyPredictor:
11-
def __init__(self, model_name):
11+
def __init__(self, model_name: str):
12+
"""
13+
Dummy predictor for testing cache decorator.
14+
:param model_name: Name of the model instance (used for key separation).
15+
"""
1216
self.model_name = model_name
1317

1418
@g_cache.batch_decorator
15-
def predict(self, smiles_list: tuple[str]):
19+
def predict(self, smiles_list: tuple[str]) -> list[str]:
20+
"""
21+
Dummy predict method to simulate model inference.
22+
Returns list of predictions with predictable format.
23+
"""
1624
# Simple predictable dummy function for tests
1725
return [f"{self.model_name}_P{i}" for i in range(len(smiles_list))]
1826

1927

2028
class TestPerSmilesPerModelLRUCache(unittest.TestCase):
21-
def setUp(self):
29+
def setUp(self) -> None:
30+
"""
31+
Set up a temporary cache file and cache instance before each test.
32+
"""
2233
# Create temp file for persistence tests
2334
self.temp_file = tempfile.NamedTemporaryFile(delete=False)
2435
self.temp_file.close()
2536
self.cache = PerSmilesPerModelLRUCache(
2637
max_size=3, persist_path=self.temp_file.name
2738
)
2839

29-
def tearDown(self):
40+
def tearDown(self) -> None:
41+
"""
42+
Clean up the temporary file after each test.
43+
"""
3044
if os.path.exists(self.temp_file.name):
3145
os.remove(self.temp_file.name)
3246

33-
def test_cache_miss_and_set_get(self):
47+
def test_cache_miss_and_set_get(self) -> None:
48+
"""
49+
Test cache miss on initial get, then set and confirm hit.
50+
"""
3451
# Initially empty
3552
self.assertEqual(len(self.cache), 0)
3653
self.assertIsNone(self.cache.get("CCC", "model1"))
@@ -41,7 +58,10 @@ def test_cache_miss_and_set_get(self):
4158
self.assertEqual(self.cache.hits, 1)
4259
self.assertEqual(self.cache.misses, 1) # One miss from first get
4360

44-
def test_cache_eviction(self):
61+
def test_cache_eviction(self) -> None:
62+
"""
63+
Test LRU eviction when capacity is exceeded.
64+
"""
4565
self.cache.set("a", "m", "v1")
4666
self.cache.set("b", "m", "v2")
4767
self.cache.set("c", "m", "v3")
@@ -52,7 +72,13 @@ def test_cache_eviction(self):
5272
self.assertIsNone(self.cache.get("a", "m")) # 'a' evicted
5373
self.assertIsNotNone(self.cache.get("d", "m")) # 'd' present
5474

55-
def test_batch_decorator_hits_and_misses(self):
75+
def test_batch_decorator_hits_and_misses(self) -> None:
76+
"""
77+
Test decorator behavior on batch prediction:
78+
- first call (all misses)
79+
- second call (mixed hits and misses)
80+
- third call (more hits and misses)
81+
"""
5682
predictor = DummyPredictor("modelA")
5783
predictor2 = DummyPredictor("modelB")
5884

@@ -120,7 +146,10 @@ def test_batch_decorator_hits_and_misses(self):
120146
stats_after_third["misses"], 14
121147
) # additional 3 misses for GGG, HHH, ZZZ
122148

123-
def test_persistence_save_and_load(self):
149+
def test_persistence_save_and_load(self) -> None:
150+
"""
151+
Test that cache is properly saved to disk and reloaded.
152+
"""
124153
# Set some values
125154
self.cache.set("sm1", "modelX", "val1")
126155
self.cache.set("sm2", "modelX", "val2")
@@ -137,7 +166,10 @@ def test_persistence_save_and_load(self):
137166
self.assertEqual(new_cache.get("sm1", "modelX"), "val1")
138167
self.assertEqual(new_cache.get("sm2", "modelX"), "val2")
139168

140-
def test_clear_cache(self):
169+
def test_clear_cache(self) -> None:
170+
"""
171+
Test clearing the cache and removing persisted file.
172+
"""
141173
self.cache.set("x", "m", "v")
142174
self.cache.save()
143175
self.assertTrue(os.path.exists(self.temp_file.name))

0 commit comments

Comments
 (0)