1717from typing import (
1818 Any ,
1919 Callable ,
20+ Collection ,
2021 Generic ,
2122 Iterable ,
23+ List ,
2224 Optional ,
2325 Type ,
2426 TypeVar ,
@@ -57,13 +59,56 @@ class _Node:
5759 __slots__ = ["prev_node" , "next_node" , "key" , "value" , "callbacks" ]
5860
5961 def __init__ (
60- self , prev_node , next_node , key , value , callbacks : Optional [set ] = None
62+ self ,
63+ prev_node ,
64+ next_node ,
65+ key ,
66+ value ,
67+ callbacks : Collection [Callable [[], None ]] = (),
6168 ):
6269 self .prev_node = prev_node
6370 self .next_node = next_node
6471 self .key = key
6572 self .value = value
66- self .callbacks = callbacks or set ()
73+
74+ # Set of callbacks to run when the node gets deleted. We store as a list
75+ # rather than a set to keep memory usage down (and since we expect few
76+ # entries per node, the performance of checking for duplication in a
77+ # list vs using a set is negligible).
78+ #
79+ # Note that we store this as an optional list to keep the memory
80+ # footprint down. Storing `None` is free as its a singleton, while empty
81+ # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
82+ # thing and used sets).
83+ self .callbacks = None # type: Optional[List[Callable[[], None]]]
84+
85+ self .add_callbacks (callbacks )
86+
87+ def add_callbacks (self , callbacks : Collection [Callable [[], None ]]) -> None :
88+ """Add to stored list of callbacks, removing duplicates."""
89+
90+ if not callbacks :
91+ return
92+
93+ if not self .callbacks :
94+ self .callbacks = []
95+
96+ for callback in callbacks :
97+ if callback not in self .callbacks :
98+ self .callbacks .append (callback )
99+
100+ def run_and_clear_callbacks (self ) -> None :
101+ """Run all callbacks and clear the stored list of callbacks. Used when
102+ the node is being deleted.
103+ """
104+
105+ if not self .callbacks :
106+ return
107+
108+ for callback in self .callbacks :
109+ callback ()
110+
111+ self .callbacks = None
67112
68113
69114class LruCache (Generic [KT , VT ]):
@@ -177,10 +222,10 @@ def cache_len():
177222
178223 self .len = synchronized (cache_len )
179224
180- def add_node (key , value , callbacks : Optional [ set ] = None ):
225+ def add_node (key , value , callbacks : Collection [ Callable [[], None ]] = () ):
181226 prev_node = list_root
182227 next_node = prev_node .next_node
183- node = _Node (prev_node , next_node , key , value , callbacks or set () )
228+ node = _Node (prev_node , next_node , key , value , callbacks )
184229 prev_node .next_node = node
185230 next_node .prev_node = node
186231 cache [key ] = node
@@ -211,16 +256,15 @@ def delete_node(node):
211256 deleted_len = size_callback (node .value )
212257 cached_cache_len [0 ] -= deleted_len
213258
214- for cb in node .callbacks :
215- cb ()
216- node .callbacks .clear ()
259+ node .run_and_clear_callbacks ()
260+
217261 return deleted_len
218262
219263 @overload
220264 def cache_get (
221265 key : KT ,
222266 default : Literal [None ] = None ,
223- callbacks : Iterable [Callable [[], None ]] = ...,
267+ callbacks : Collection [Callable [[], None ]] = ...,
224268 update_metrics : bool = ...,
225269 ) -> Optional [VT ]:
226270 ...
@@ -229,7 +273,7 @@ def cache_get(
229273 def cache_get (
230274 key : KT ,
231275 default : T ,
232- callbacks : Iterable [Callable [[], None ]] = ...,
276+ callbacks : Collection [Callable [[], None ]] = ...,
233277 update_metrics : bool = ...,
234278 ) -> Union [T , VT ]:
235279 ...
@@ -238,13 +282,13 @@ def cache_get(
238282 def cache_get (
239283 key : KT ,
240284 default : Optional [T ] = None ,
241- callbacks : Iterable [Callable [[], None ]] = (),
285+ callbacks : Collection [Callable [[], None ]] = (),
242286 update_metrics : bool = True ,
243287 ):
244288 node = cache .get (key , None )
245289 if node is not None :
246290 move_node_to_front (node )
247- node .callbacks . update (callbacks )
291+ node .add_callbacks (callbacks )
248292 if update_metrics and metrics :
249293 metrics .inc_hits ()
250294 return node .value
@@ -260,10 +304,8 @@ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
260304 # We sometimes store large objects, e.g. dicts, which cause
261305 # the inequality check to take a long time. So let's only do
262306 # the check if we have some callbacks to call.
263- if node .callbacks and value != node .value :
264- for cb in node .callbacks :
265- cb ()
266- node .callbacks .clear ()
307+ if value != node .value :
308+ node .run_and_clear_callbacks ()
267309
268310 # We don't bother to protect this by value != node.value as
269311 # generally size_callback will be cheap compared with equality
@@ -273,7 +315,7 @@ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
273315 cached_cache_len [0 ] -= size_callback (node .value )
274316 cached_cache_len [0 ] += size_callback (value )
275317
276- node .callbacks . update (callbacks )
318+ node .add_callbacks (callbacks )
277319
278320 move_node_to_front (node )
279321 node .value = value
@@ -326,8 +368,7 @@ def cache_clear() -> None:
326368 list_root .next_node = list_root
327369 list_root .prev_node = list_root
328370 for node in cache .values ():
329- for cb in node .callbacks :
330- cb ()
371+ node .run_and_clear_callbacks ()
331372 cache .clear ()
332373 if size_callback :
333374 cached_cache_len [0 ] = 0
0 commit comments