1313import multiprocessing as mp
1414import functools
1515import atexit
16+ from collections import OrderedDict
1617
1718import numpy as np
1819import networkx as nx
2122from scipy .sparse import coo_matrix
2223import matplotlib .pyplot as plt
2324
24- # Global cache for hypercube graphs to avoid recreation
25- _HYPERCUBE_CACHE : Dict [int , nx .Graph ] = {}
25+
26+ class LRUCache :
27+ """Least Recently Used cache with size limit."""
28+
29+ def __init__ (self , max_size : int ):
30+ self .cache = OrderedDict ()
31+ self .max_size = max_size
32+ self .hits = 0
33+ self .misses = 0
34+
35+ def get (self , key ):
36+ """Get item from cache, updating access order."""
37+ if key in self .cache :
38+ self .hits += 1
39+ # Move to end (most recently used)
40+ self .cache .move_to_end (key )
41+ return self .cache [key ]
42+ self .misses += 1
43+ return None
44+
45+ def set (self , key , value ):
46+ """Set item in cache, evicting LRU if needed."""
47+ if key in self .cache :
48+ # Update and move to end
49+ self .cache .move_to_end (key )
50+ self .cache [key ] = value
51+ # Evict least recently used if over limit
52+ if len (self .cache ) > self .max_size :
53+ evicted = self .cache .popitem (last = False )
54+ logging .debug (f"Evicted cache entry: { evicted [0 ]} " )
55+
56+ def clear (self ):
57+ """Clear cache and reset statistics."""
58+ self .cache .clear ()
59+ self .hits = 0
60+ self .misses = 0
61+
62+ def __contains__ (self , key ):
63+ return key in self .cache
64+
65+ def __len__ (self ):
66+ return len (self .cache )
67+
68+ def info (self ):
69+ """Get cache statistics."""
70+ return {
71+ 'size' : len (self .cache ),
72+ 'max_size' : self .max_size ,
73+ 'hits' : self .hits ,
74+ 'misses' : self .misses ,
75+ 'hit_rate' : self .hits / (self .hits + self .misses ) if (self .hits + self .misses ) > 0 else 0
76+ }
77+
78+
79+ # Cache size limits based on memory footprint
80+ MAX_HYPERCUBE_CACHE_SIZE = 16 # Large graphs (increased from 10)
81+ MAX_POWER2_CACHE_SIZE = 100 # Small lists
82+ MAX_HAMMING_CACHE_SIZE = 20 # Medium dictionaries
83+
84+ # Global LRU caches with size limits
85+ _HYPERCUBE_CACHE = LRUCache (MAX_HYPERCUBE_CACHE_SIZE )
86+ _POWER2_NODES_CACHE = LRUCache (MAX_POWER2_CACHE_SIZE )
87+ _HAMMING_LAYERS_CACHE = LRUCache (MAX_HAMMING_CACHE_SIZE )
2688
2789# Global persistent worker pool
2890_GLOBAL_POOL : Optional [mp .Pool ] = None
2991_POOL_SIZE : Optional [int ] = None
3092
31- # Cache for pre-computed values
32- _POWER2_NODES_CACHE : Dict [int , List [int ]] = {}
33- _HAMMING_LAYERS_CACHE : Dict [int , Dict [int , List [int ]]] = {}
34-
3593
3694def fix_random_seed (seed : int ) -> None :
3795 """
@@ -112,6 +170,7 @@ def cleanup_pool() -> None:
112170def get_cached_hypercube (dimension : int ) -> nx .Graph :
113171 """
114172 Get a cached hypercube graph or create and cache a new one.
173+ Uses LRU eviction when cache is full.
115174
116175 Args:
117176 dimension: Dimension of the hypercube.
@@ -121,17 +180,22 @@ def get_cached_hypercube(dimension: int) -> nx.Graph:
121180 """
122181 global _HYPERCUBE_CACHE
123182
124- if dimension not in _HYPERCUBE_CACHE :
125- G = nx .hypercube_graph (dimension )
126- G = nx .convert_node_labels_to_integers (G )
127- _HYPERCUBE_CACHE [dimension ] = G
183+ cached = _HYPERCUBE_CACHE .get (dimension )
184+ if cached is not None :
185+ return cached
186+
187+ # Create new graph
188+ G = nx .hypercube_graph (dimension )
189+ G = nx .convert_node_labels_to_integers (G )
190+ _HYPERCUBE_CACHE .set (dimension , G )
128191
129- return _HYPERCUBE_CACHE [ dimension ]
192+ return G
130193
131194
132195def get_cached_power2_nodes (dimension : int ) -> List [int ]:
133196 """
134197 Get cached power-of-2 nodes for a given dimension.
198+ Uses LRU eviction when cache is full.
135199
136200 Args:
137201 dimension: Number of qubits.
@@ -141,19 +205,25 @@ def get_cached_power2_nodes(dimension: int) -> List[int]:
141205 """
142206 global _POWER2_NODES_CACHE
143207
144- if dimension not in _POWER2_NODES_CACHE :
145- N = 2 ** dimension
146- _POWER2_NODES_CACHE [dimension ] = [
147- node for node in range (N )
148- if node > 0 and (node & (node - 1 )) == 0
149- ]
208+ cached = _POWER2_NODES_CACHE .get (dimension )
209+ if cached is not None :
210+ return cached
211+
212+ # Calculate power-of-2 nodes
213+ N = 2 ** dimension
214+ nodes = [
215+ node for node in range (N )
216+ if node > 0 and (node & (node - 1 )) == 0
217+ ]
218+ _POWER2_NODES_CACHE .set (dimension , nodes )
150219
151- return _POWER2_NODES_CACHE [ dimension ]
220+ return nodes
152221
153222
154223def get_cached_hamming_layers (dimension : int ) -> Dict [int , List [int ]]:
155224 """
156225 Get cached Hamming weight layers for a given dimension.
226+ Uses LRU eviction when cache is full.
157227
158228 Args:
159229 dimension: Number of qubits.
@@ -163,14 +233,84 @@ def get_cached_hamming_layers(dimension: int) -> Dict[int, List[int]]:
163233 """
164234 global _HAMMING_LAYERS_CACHE
165235
166- if dimension not in _HAMMING_LAYERS_CACHE :
167- N = 2 ** dimension
168- layers = {}
169- for k in range (dimension + 1 ):
170- layers [k ] = [node for node in range (N ) if hamming_weight (node ) == k ]
171- _HAMMING_LAYERS_CACHE [dimension ] = layers
236+ cached = _HAMMING_LAYERS_CACHE .get (dimension )
237+ if cached is not None :
238+ return cached
239+
240+ # Calculate Hamming layers
241+ N = 2 ** dimension
242+ layers = {}
243+ for k in range (dimension + 1 ):
244+ layers [k ] = [node for node in range (N ) if hamming_weight (node ) == k ]
245+ _HAMMING_LAYERS_CACHE .set (dimension , layers )
246+
247+ return layers
248+
249+
250+ def clear_caches () -> None :
251+ """Clear all global caches to free memory."""
252+ global _HYPERCUBE_CACHE , _POWER2_NODES_CACHE , _HAMMING_LAYERS_CACHE
253+
254+ _HYPERCUBE_CACHE .clear ()
255+ _POWER2_NODES_CACHE .clear ()
256+ _HAMMING_LAYERS_CACHE .clear ()
257+
258+ logging .info ("All caches cleared" )
259+
260+
261+ def get_cache_info () -> Dict [str , Dict ]:
262+ """
263+ Get information about all caches.
264+
265+ Returns:
266+ Dictionary with cache statistics for each cache.
267+ """
268+ return {
269+ 'hypercube' : _HYPERCUBE_CACHE .info (),
270+ 'power2_nodes' : _POWER2_NODES_CACHE .info (),
271+ 'hamming_layers' : _HAMMING_LAYERS_CACHE .info ()
272+ }
273+
274+
275+ def set_cache_sizes (
276+ hypercube : Optional [int ] = None ,
277+ power2 : Optional [int ] = None ,
278+ hamming : Optional [int ] = None
279+ ) -> None :
280+ """
281+ Adjust cache size limits. Clears excess entries if sizes are reduced.
282+
283+ Args:
284+ hypercube: New size limit for hypercube cache
285+ power2: New size limit for power2 nodes cache
286+ hamming: New size limit for Hamming layers cache
287+ """
288+ global _HYPERCUBE_CACHE , _POWER2_NODES_CACHE , _HAMMING_LAYERS_CACHE
289+
290+ if hypercube is not None :
291+ old_size = _HYPERCUBE_CACHE .max_size
292+ _HYPERCUBE_CACHE .max_size = hypercube
293+ if hypercube < old_size :
294+ # Evict excess entries
295+ while len (_HYPERCUBE_CACHE ) > hypercube :
296+ _HYPERCUBE_CACHE .cache .popitem (last = False )
297+ logging .debug (f"Reduced hypercube cache size to { hypercube } " )
298+
299+ if power2 is not None :
300+ old_size = _POWER2_NODES_CACHE .max_size
301+ _POWER2_NODES_CACHE .max_size = power2
302+ if power2 < old_size :
303+ while len (_POWER2_NODES_CACHE ) > power2 :
304+ _POWER2_NODES_CACHE .cache .popitem (last = False )
305+ logging .debug (f"Reduced power2 cache size to { power2 } " )
172306
173- return _HAMMING_LAYERS_CACHE [dimension ]
307+ if hamming is not None :
308+ old_size = _HAMMING_LAYERS_CACHE .max_size
309+ _HAMMING_LAYERS_CACHE .max_size = hamming
310+ if hamming < old_size :
311+ while len (_HAMMING_LAYERS_CACHE ) > hamming :
312+ _HAMMING_LAYERS_CACHE .cache .popitem (last = False )
313+ logging .debug (f"Reduced hamming cache size to { hamming } " )
174314
175315
176316def optimized_uniform_spanning_tree (G : nx .Graph , dimension : int ) -> nx .Graph :
@@ -628,6 +768,7 @@ def generate_random_forest(
628768
629769 # Optional: save first 5 tree visualizations
630770 if save_tree and m < 5 :
771+ # Initialize current_fig before try block to ensure cleanup in all cases
631772 current_fig = None
632773 try :
633774 G = nx .hypercube_graph (num_qubits )
@@ -655,10 +796,16 @@ def generate_random_forest(
655796 if show_tree and m == 0 :
656797 # this will pop up the first tree in-line (or in a window)
657798 plt .show ()
799+ except Exception as e :
800+ # Log the error for debugging but don't stop the process
801+ logging .warning (f"Failed to visualize tree { m } : { e } " )
658802 finally :
659- # Always close the figure to prevent memory leaks, but only if it was created
803+ # Always close the figure to prevent memory leaks
804+ # This also closes any figures created by plt functions even if current_fig wasn't assigned
660805 if current_fig is not None :
661806 plt .close (current_fig )
807+ # Extra safety: close all figures to prevent any potential leaks
808+ plt .close ('all' )
662809
663810 # Store signs for this tree in pre-allocated array
664811 signs_stack [m ] = signs
0 commit comments