44
55from chebifier import PerSmilesPerModelLRUCache
66
7- g_cache = PerSmilesPerModelLRUCache (max_size = 3 )
7+ g_cache = PerSmilesPerModelLRUCache (max_size = 100 , persist_path = None )
88
99
1010class DummyPredictor :
@@ -14,7 +14,7 @@ def __init__(self, model_name):
1414 @g_cache .batch_decorator
1515 def predict (self , smiles_list : tuple [str ]):
1616 # Simple predictable dummy function for tests
17- return [f"{ self .model_name } { i } " for i in range (len (smiles_list ))]
17+ return [f"{ self .model_name } _P { i } " for i in range (len (smiles_list ))]
1818
1919
2020class TestPerSmilesPerModelLRUCache (unittest .TestCase ):
@@ -73,30 +73,52 @@ def test_batch_decorator_hits_and_misses(self):
7373 ["modelB_P0" , "modelB_P1" , "modelB_P2" , "modelB_P3" , "modelB_P4" ],
7474 )
7575 stats_after_first = g_cache .stats ()
76- self .assertEqual (stats_after_first ["misses" ], 3 )
76+ self .assertEqual (
77+ stats_after_first ["misses" ], 10
78+ ) # 5 for modelA and 5 for modelB
79+ self .assertEqual (stats_after_first ["hits" ], 0 )
80+ self .assertEqual (len (g_cache ), 10 ) # 5 for each model
81+
82+ # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2",
83+ # ("DDD", "modelA"): "modelA_P3", ("EEE", "modelA"): "modelA_P4",
84+ # ("AAA", "modelB"): "modelB_P0", ("BBB", "modelB"): "modelB_P1", ("CCC", "modelB"): "modelB_P2",}
85+ # ("DDD", "modelB"): "modelB_P3", ("EEE", "modelB"): "modelB_P4"}
7786
78- # cache = {("AAA", "modelA"): "modelA_P0", ("BBB", "modelA"): "modelA_P1", ("CCC", "modelA"): "modelA_P2"}
7987 # Second call with some hits and some misses
8088 results2 = predictor .predict (["FFF" , "DDD" ])
81- # AAA from cache
82- # FFF is not in cache, so it predicted, hence it has P0 as its the only one passed to prediction function
83- # and dummy predictor returns iterates over the smiles list and return P{idx} corresponding to the index
84- self .assertListEqual (results2 , ["P3 " , "P0 " ])
89+ # DDD from cache
90+ # FFF is not in cache, so its predicted, hence it has P0 as its the only one passed to prediction function
91+ # and dummy predictor iterates over the smiles list and returns P{idx} corresponding to the index
92+ self .assertListEqual (results2 , ["modelA_P0 " , "modelA_P3 " ])
8593 stats_after_second = g_cache .stats ()
86- self .assertEqual (stats_after_second ["hits" ], 1 )
87- self .assertEqual (stats_after_second ["misses" ], 4 )
94+ self .assertEqual (stats_after_second ["hits" ], 1 ) # additional 1 hit for DDD
95+ self .assertEqual (stats_after_second ["misses" ], 11 ) # 1 miss for FFF
8896
89- # cache = {("AAA", "modelA"): "P0 ", ("BBB", "modelA"): "P1 ", ("CCC", "modelA"): "P2 ",
90- # ("DDD", "modelA"): "P3 ", ("EEE", "modelA"): "P4 ", ("FFF", "modelA"): "P0" }
97+ # cache = {("AAA", "modelA"): "modelA_P0 ", ("BBB", "modelA"): "modelA_P1 ", ("CCC", "modelA"): "modelA_P2 ",
98+ # ("DDD", "modelA"): "modelA_P3 ", ("EEE", "modelA"): "modelA_P4 ", ("FFF", "modelA"): "modelA_P0", ... }
9199
92100 # Third call with some hits and some misses
93101 results3 = predictor .predict (["EEE" , "GGG" , "DDD" , "HHH" , "BBB" , "ZZZ" ])
94102 # Here, predictions for [EEE, DDD, BBB] are retrived from cache,
95103 # while [GGG, HHH, ZZZ] are not in cache and hence passe to the prediction function
96- self .assertListEqual (results3 , ["P4" , "P0" , "P3" , "P0" , "P1" , "P0" ])
104+ self .assertListEqual (
105+ results3 ,
106+ [
107+ "modelA_P4" , # EEE from cache
108+ "modelA_P0" , # GGG not in cache, so it predicted, hence it has P0 as its the only one passed to prediction function
109+ "modelA_P3" , # DDD from cache
110+ "modelA_P1" , # HHH not in cache, so it predicted, hence it has P1 as its the only one passed to prediction function
111+ "modelA_P1" , # BBB from cache
112+ "modelA_P2" , # ZZZ not in cache, so it predicted, hence it has P2 as its the only one passed to prediction function
113+ ],
114+ )
97115 stats_after_third = g_cache .stats ()
98- self .assertEqual (stats_after_third ["hits" ], 1 )
99- self .assertEqual (stats_after_third ["misses" ], 4 )
116+ self .assertEqual (
117+ stats_after_third ["hits" ], 4
118+ ) # additional 3 hits for EEE, DDD, BBB
119+ self .assertEqual (
120+ stats_after_third ["misses" ], 14
121+ ) # additional 3 misses for GGG, HHH, ZZZ
100122
101123 def test_persistence_save_and_load (self ):
102124 # Set some values
0 commit comments