Skip to content

Commit 9b636db

Browse files
committed
refine test for realistic scenario with 2 models
1 parent cdd7de9 commit 9b636db

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

chebifier/_custom_cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,11 @@ def _save_cache(self) -> None:
125125

126126
def _load_cache(self) -> None:
127127
"""Load the cache from disk."""
128-
if os.path.exists(self._persist_path):
128+
if (
129+
self._persist_path
130+
and os.path.exists(self._persist_path)
131+
and os.path.getsize(self._persist_path) > 0
132+
):
129133
try:
130134
with open(self._persist_path, "rb") as f:
131135
loaded = pickle.load(f)

tests/test_cache.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from chebifier import PerSmilesPerModelLRUCache
66

7-
g_cache = PerSmilesPerModelLRUCache(max_size=3)
7+
g_cache = PerSmilesPerModelLRUCache(max_size=100, persist_path=None)
88

99

1010
class DummyPredictor:
@@ -14,7 +14,7 @@ def __init__(self, model_name):
1414
@g_cache.batch_decorator
1515
def predict(self, smiles_list: tuple[str]):
1616
# Simple predictable dummy function for tests
17-
return [f"{self.model_name}{i}" for i in range(len(smiles_list))]
17+
return [f"{self.model_name}_P{i}" for i in range(len(smiles_list))]
1818

1919

2020
class TestPerSmilesPerModelLRUCache(unittest.TestCase):
@@ -73,30 +73,52 @@ def test_batch_decorator_hits_and_misses(self):
7373
["modelB_P0", "modelB_P1", "modelB_P2", "modelB_P3", "modelB_P4"],
7474
)
7575
stats_after_first = g_cache.stats()
76-
self.assertEqual(stats_after_first["misses"], 3)
76+
self.assertEqual(
77+
stats_after_first["misses"], 10
78+
) # 5 for modelA and 5 for modelB
79+
self.assertEqual(stats_after_first["hits"], 0)
80+
self.assertEqual(len(g_cache), 10) # 5 for each model
81+
82+
# cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
83+
# ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4",
84+
# ("AAA", "modelB"): "modelB_P0", ("BBB", "modelB"): "modelB_P1", ("CCC", "modelB"): "modelB_P2",
85+
# ("DDD", "modelB"): "modelB_P3", ("EEE", "modelB"): "modelB_P4"}
7786

78-
# cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2"}
7987
# Second call with some hits and some misses
8088
results2 = predictor.predict(["FFF", "DDD"])
81-
# AAA from cache
82-
# FFF is not in cache, so it predicted, hence it has P0 as its the only one passed to prediction function
83-
# and dummy predictor returns iterates over the smiles list and return P{idx} corresponding to the index
84-
self.assertListEqual(results2, ["P3", "P0"])
89+
# DDD from cache
90+
# FFF is not in cache, so it is predicted; hence it gets P0, as it is the only one passed to the prediction function
91+
# and dummy predictor iterates over the smiles list and returns P{idx} corresponding to the index
92+
self.assertListEqual(results2, ["modelA_P0", "modelA_P3"])
8593
stats_after_second = g_cache.stats()
86-
self.assertEqual(stats_after_second["hits"], 1)
87-
self.assertEqual(stats_after_second["misses"], 4)
94+
self.assertEqual(stats_after_second["hits"], 1) # additional 1 hit for DDD
95+
self.assertEqual(stats_after_second["misses"], 11) # 1 miss for FFF
8896

89-
# cache = {("AAA", "modelA"): "P0", ("BBB", "modelA"): "P1", ("CCC", "modelA"): "P2",
90-
# ("DDD", "modelA"): "P3", ("EEE", "modelA"): "P4", ("FFF", "modelA"): "P0"}
97+
# cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
98+
# ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4", ("FFF", "modelA"): "modelA_P0", ...}
9199

92100
# Third call with some hits and some misses
93101
results3 = predictor.predict(["EEE", "GGG", "DDD", "HHH", "BBB", "ZZZ"])
94102
# Here, predictions for [EEE, DDD, BBB] are retrieved from cache,
95103
# while [GGG, HHH, ZZZ] are not in cache and hence passed to the prediction function
96-
self.assertListEqual(results3, ["P4", "P0", "P3", "P0", "P1", "P0"])
104+
self.assertListEqual(
105+
results3,
106+
[
107+
"modelA_P4", # EEE from cache
108+
"modelA_P0", # GGG not in cache, so it is predicted; it is index 0 of the uncached batch [GGG, HHH, ZZZ] passed to the prediction function, hence P0
109+
"modelA_P3", # DDD from cache
110+
"modelA_P1", # HHH not in cache, so it is predicted; it is index 1 of the uncached batch [GGG, HHH, ZZZ] passed to the prediction function, hence P1
111+
"modelA_P1", # BBB from cache
112+
"modelA_P2", # ZZZ not in cache, so it is predicted; it is index 2 of the uncached batch [GGG, HHH, ZZZ] passed to the prediction function, hence P2
113+
],
114+
)
97115
stats_after_third = g_cache.stats()
98-
self.assertEqual(stats_after_third["hits"], 1)
99-
self.assertEqual(stats_after_third["misses"], 4)
116+
self.assertEqual(
117+
stats_after_third["hits"], 4
118+
) # additional 3 hits for EEE, DDD, BBB
119+
self.assertEqual(
120+
stats_after_third["misses"], 14
121+
) # additional 3 misses for GGG, HHH, ZZZ
100122

101123
def test_persistence_save_and_load(self):
102124
# Set some values

0 commit comments

Comments
 (0)