88
99
1010class PerSmilesPerModelLRUCache :
11+ """
12+ A thread-safe, optionally persistent LRU cache for storing
13+ (SMILES, model_name) → result mappings.
14+ """
15+
1116 def __init__ (self , max_size : int = 100 , persist_path : str | None = None ):
12- self ._cache = OrderedDict ()
17+ """
18+ Initialize the cache.
19+
20+ Args:
21+ max_size (int): Maximum number of items to keep in the cache.
22+ persist_path (str | None): Optional path to persist cache using pickle.
23+ """
24+ self ._cache : OrderedDict [tuple [str , str ], Any ] = OrderedDict ()
1325 self ._max_size = max_size
1426 self ._lock = threading .Lock ()
1527 self ._persist_path = persist_path
@@ -21,6 +33,16 @@ def __init__(self, max_size: int = 100, persist_path: str | None = None):
2133 self ._load_cache ()
2234
2335 def get (self , smiles : str , model_name : str ) -> Any | None :
36+ """
37+ Retrieve value from cache if present, otherwise return None.
38+
39+ Args:
40+ smiles (str): SMILES string key.
41+ model_name (str): Model identifier.
42+
43+ Returns:
44+ Any | None: Cached value or None.
45+ """
2446 key = (smiles , model_name )
2547 with self ._lock :
2648 if key in self ._cache :
@@ -32,6 +54,14 @@ def get(self, smiles: str, model_name: str) -> Any | None:
3254 return None
3355
3456 def set (self , smiles : str , model_name : str , value : Any ) -> None :
57+ """
58+ Store value in cache under (smiles, model_name) key.
59+
60+ Args:
61+ smiles (str): SMILES string key.
62+ model_name (str): Model identifier.
63+ value (Any): Value to cache.
64+ """
3565 assert value is not None , "Value must not be None"
3666 key = (smiles , model_name )
3767 with self ._lock :
@@ -42,6 +72,9 @@ def set(self, smiles: str, model_name: str, value: Any) -> None:
4272 self ._cache .popitem (last = False )
4373
4474 def clear (self ) -> None :
75+ """
76+ Clear the cache and remove the persistence file if present.
77+ """
4578 self ._save_cache ()
4679 with self ._lock :
4780 self ._cache .clear ()
@@ -50,23 +83,38 @@ def clear(self) -> None:
5083 if self ._persist_path and os .path .exists (self ._persist_path ):
5184 os .remove (self ._persist_path )
5285
def stats(self) -> dict[str, int]:
    """
    Report the cache's lookup statistics.

    Returns:
        dict[str, int]: Mapping with the keys 'hits' and 'misses', holding
            the current values of ``self.hits`` and ``self.misses``.
    """
    hit_count = self.hits
    miss_count = self.misses
    return dict(hits=hit_count, misses=miss_count)
5594
5695 def batch_decorator (self , func : Callable ) -> Callable :
57- """Decorator for class methods that accept a batch of SMILES as a tuple,
58- and want caching per (smiles, model_name) combination.
96+ """
97+ Decorator for class methods that accept a batch of SMILES as a list,
98+ and cache predictions per (smiles, model_name) key.
99+
100+ The instance is expected to have a `model_name` attribute.
101+
102+ Args:
103+ func (Callable): The method to decorate.
104+
105+ Returns:
106+ Callable: The wrapped method.
59107 """
60108
61109 @wraps (func )
62- def wrapper (instance , smiles_list : list [str ]):
110+ def wrapper (instance , smiles_list : list [str ]) -> list [ Any ] :
63111 assert isinstance (smiles_list , list ), "smiles_list must be a list."
64112 model_name = getattr (instance , "model_name" , None )
65113 assert model_name is not None , "Instance must have a model_name attribute."
66114
67- results = []
68- missing_smiles = []
69- missing_indices = []
115+ results : list [ tuple [ int , Any ]] = []
116+ missing_smiles : list [ str ] = []
117+ missing_indices : list [ int ] = []
70118
71119 # First: try to fetch all from cache
72120 for i , smiles in enumerate (smiles_list ):
@@ -82,7 +130,8 @@ def wrapper(instance, smiles_list: list[str]):
82130 new_results = func (instance , tuple (missing_smiles ))
83131 assert isinstance (
84132 new_results , Iterable
85- ), "Function must return an Iterable."
133+ ), "Function must return an Iterable."
134+
86135 # Save to cache and append
87136 for smiles , prediction , missing_idx in zip (
88137 missing_smiles , new_results , missing_indices
@@ -101,21 +150,41 @@ def wrapper(instance, smiles_list: list[str]):
101150
102151 return wrapper
103152
def __len__(self) -> int:
    """
    Count the entries currently stored in the cache.

    The internal lock is held while the size is read.

    Returns:
        int: Number of entries in the cache.
    """
    with self._lock:
        entry_count = len(self._cache)
    return entry_count
107162
def __repr__(self) -> str:
    """
    Describe the cache by the representation of its backing mapping.

    Returns:
        str: ``repr`` of the underlying OrderedDict.
    """
    backing_store = self._cache
    return repr(backing_store)
110171
def save(self) -> None:
    """
    Persist the cache by delegating to ``_save_cache``.

    Nothing is written unless a persistence path is configured.
    """
    persist = self._save_cache
    persist()
113177
def load(self) -> None:
    """
    Restore the cache by delegating to ``_load_cache``.

    Nothing is read unless a persistence path is configured.
    """
    restore = self._load_cache
    restore()
116183
117184 def _save_cache (self ) -> None :
118- """Serialize the cache to disk."""
185+ """
186+ Serialize the cache to disk using pickle.
187+ """
119188 if self ._persist_path :
120189 try :
121190 with open (self ._persist_path , "wb" ) as f :
@@ -124,7 +193,9 @@ def _save_cache(self) -> None:
124193 print (f"[Cache Save Error] { e } " )
125194
126195 def _load_cache (self ) -> None :
127- """Load the cache from disk."""
196+ """
197+ Load the cache from disk, if the file exists and is non-empty.
198+ """
128199 if (
129200 self ._persist_path
130201 and os .path .exists (self ._persist_path )
0 commit comments