
Commit 9561c9a

Update tests to cover cached graph generation
1 parent 6fbbbc2 commit 9561c9a

File tree

4 files changed: 389 additions & 53 deletions

hadamard_random_forest/__init__.py

Lines changed: 10 additions & 2 deletions

```diff
@@ -10,15 +10,19 @@
 
 # Package Metadata
 __title__ = "hadamard_random_forest"
-__version__ = "0.1.0"
+__version__ = "0.2.0"
 __license__ = "MIT"
 
 # Core functionality from main module
 from .random_forest import (
     fix_random_seed,
     optimized_uniform_spanning_tree,
     generate_hypercube_tree,
-    generate_random_forest
+    generate_random_forest,
+    # Cache management
+    clear_caches,
+    get_cache_info,
+    set_cache_sizes
 )
 
 from .sample import (
@@ -50,6 +54,10 @@
     "get_samples_noisy",
     "get_samples_hardware",
     "get_statevector",
+    # cache management
+    "clear_caches",
+    "get_cache_info",
+    "set_cache_sizes",
     # utilities
     "random_statevector",
     "logarithmic_negativity",
```

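The three cache-management helpers promoted to the public API above can be driven directly from the package namespace. A minimal usage sketch (the alias `hrf` and the size value are illustrative, not from this commit):

```python
import hadamard_random_forest as hrf

# Tighten the hypercube cache; excess entries are evicted LRU-first
hrf.set_cache_sizes(hypercube=8)

# Each cache reports size, max_size, hits, misses, and hit_rate
info = hrf.get_cache_info()
print(info["hypercube"]["hit_rate"])

# Release all cached graphs, node lists, and layer dictionaries
hrf.clear_caches()
```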
hadamard_random_forest/random_forest.py

Lines changed: 173 additions & 26 deletions
```diff
@@ -13,6 +13,7 @@
 import multiprocessing as mp
 import functools
 import atexit
+from collections import OrderedDict
 
 import numpy as np
 import networkx as nx
```
```diff
@@ -21,17 +22,74 @@
 from scipy.sparse import coo_matrix
 import matplotlib.pyplot as plt
 
-# Global cache for hypercube graphs to avoid recreation
-_HYPERCUBE_CACHE: Dict[int, nx.Graph] = {}
+
+class LRUCache:
+    """Least Recently Used cache with size limit."""
+
+    def __init__(self, max_size: int):
+        self.cache = OrderedDict()
+        self.max_size = max_size
+        self.hits = 0
+        self.misses = 0
+
+    def get(self, key):
+        """Get item from cache, updating access order."""
+        if key in self.cache:
+            self.hits += 1
+            # Move to end (most recently used)
+            self.cache.move_to_end(key)
+            return self.cache[key]
+        self.misses += 1
+        return None
+
+    def set(self, key, value):
+        """Set item in cache, evicting LRU if needed."""
+        if key in self.cache:
+            # Update and move to end
+            self.cache.move_to_end(key)
+        self.cache[key] = value
+        # Evict least recently used if over limit
+        if len(self.cache) > self.max_size:
+            evicted = self.cache.popitem(last=False)
+            logging.debug(f"Evicted cache entry: {evicted[0]}")
+
+    def clear(self):
+        """Clear cache and reset statistics."""
+        self.cache.clear()
+        self.hits = 0
+        self.misses = 0
+
+    def __contains__(self, key):
+        return key in self.cache
+
+    def __len__(self):
+        return len(self.cache)
+
+    def info(self):
+        """Get cache statistics."""
+        return {
+            'size': len(self.cache),
+            'max_size': self.max_size,
+            'hits': self.hits,
+            'misses': self.misses,
+            'hit_rate': self.hits / (self.hits + self.misses) if (self.hits + self.misses) > 0 else 0
+        }
+
+
+# Cache size limits based on memory footprint
+MAX_HYPERCUBE_CACHE_SIZE = 16  # Large graphs (increased from 10)
+MAX_POWER2_CACHE_SIZE = 100    # Small lists
+MAX_HAMMING_CACHE_SIZE = 20    # Medium dictionaries
+
+# Global LRU caches with size limits
+_HYPERCUBE_CACHE = LRUCache(MAX_HYPERCUBE_CACHE_SIZE)
+_POWER2_NODES_CACHE = LRUCache(MAX_POWER2_CACHE_SIZE)
+_HAMMING_LAYERS_CACHE = LRUCache(MAX_HAMMING_CACHE_SIZE)
 
 # Global persistent worker pool
 _GLOBAL_POOL: Optional[mp.Pool] = None
 _POOL_SIZE: Optional[int] = None
 
-# Cache for pre-computed values
-_POWER2_NODES_CACHE: Dict[int, List[int]] = {}
-_HAMMING_LAYERS_CACHE: Dict[int, Dict[int, List[int]]] = {}
-
 
 def fix_random_seed(seed: int) -> None:
     """
```
```diff
@@ -112,6 +170,7 @@ def cleanup_pool() -> None:
 def get_cached_hypercube(dimension: int) -> nx.Graph:
     """
     Get a cached hypercube graph or create and cache a new one.
+    Uses LRU eviction when cache is full.
 
     Args:
         dimension: Dimension of the hypercube.
```
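The "Uses LRU eviction when cache is full" behavior referenced in these docstrings comes from the `LRUCache` class added above. A short standalone sketch of its eviction and bookkeeping semantics, assuming the class exactly as defined in this diff (keys and values here are arbitrary):

```python
cache = LRUCache(max_size=2)
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")       # hit: "a" becomes most recently used
cache.set("c", 3)    # over capacity: "b", now least recently used, is evicted
assert "a" in cache and "b" not in cache
print(cache.info())  # {'size': 2, 'max_size': 2, 'hits': 1, 'misses': 0, 'hit_rate': 1.0}
```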
```diff
@@ -121,17 +180,22 @@ def get_cached_hypercube(dimension: int) -> nx.Graph:
     """
     global _HYPERCUBE_CACHE
 
-    if dimension not in _HYPERCUBE_CACHE:
-        G = nx.hypercube_graph(dimension)
-        G = nx.convert_node_labels_to_integers(G)
-        _HYPERCUBE_CACHE[dimension] = G
+    cached = _HYPERCUBE_CACHE.get(dimension)
+    if cached is not None:
+        return cached
+
+    # Create new graph
+    G = nx.hypercube_graph(dimension)
+    G = nx.convert_node_labels_to_integers(G)
+    _HYPERCUBE_CACHE.set(dimension, G)
 
-    return _HYPERCUBE_CACHE[dimension]
+    return G
 
 
 def get_cached_power2_nodes(dimension: int) -> List[int]:
     """
     Get cached power-of-2 nodes for a given dimension.
+    Uses LRU eviction when cache is full.
 
     Args:
         dimension: Number of qubits.
```
```diff
@@ -141,19 +205,25 @@ def get_cached_power2_nodes(dimension: int) -> List[int]:
     """
     global _POWER2_NODES_CACHE
 
-    if dimension not in _POWER2_NODES_CACHE:
-        N = 2**dimension
-        _POWER2_NODES_CACHE[dimension] = [
-            node for node in range(N)
-            if node > 0 and (node & (node - 1)) == 0
-        ]
+    cached = _POWER2_NODES_CACHE.get(dimension)
+    if cached is not None:
+        return cached
+
+    # Calculate power-of-2 nodes
+    N = 2**dimension
+    nodes = [
+        node for node in range(N)
+        if node > 0 and (node & (node - 1)) == 0
+    ]
+    _POWER2_NODES_CACHE.set(dimension, nodes)
 
-    return _POWER2_NODES_CACHE[dimension]
+    return nodes
 
 
 def get_cached_hamming_layers(dimension: int) -> Dict[int, List[int]]:
     """
     Get cached Hamming weight layers for a given dimension.
+    Uses LRU eviction when cache is full.
 
     Args:
         dimension: Number of qubits.
```
```diff
@@ -163,14 +233,84 @@ def get_cached_hamming_layers(dimension: int) -> Dict[int, List[int]]:
     """
     global _HAMMING_LAYERS_CACHE
 
-    if dimension not in _HAMMING_LAYERS_CACHE:
-        N = 2**dimension
-        layers = {}
-        for k in range(dimension + 1):
-            layers[k] = [node for node in range(N) if hamming_weight(node) == k]
-        _HAMMING_LAYERS_CACHE[dimension] = layers
+    cached = _HAMMING_LAYERS_CACHE.get(dimension)
+    if cached is not None:
+        return cached
+
+    # Calculate Hamming layers
+    N = 2**dimension
+    layers = {}
+    for k in range(dimension + 1):
+        layers[k] = [node for node in range(N) if hamming_weight(node) == k]
+    _HAMMING_LAYERS_CACHE.set(dimension, layers)
+
+    return layers
+
+
+def clear_caches() -> None:
+    """Clear all global caches to free memory."""
+    global _HYPERCUBE_CACHE, _POWER2_NODES_CACHE, _HAMMING_LAYERS_CACHE
+
+    _HYPERCUBE_CACHE.clear()
+    _POWER2_NODES_CACHE.clear()
+    _HAMMING_LAYERS_CACHE.clear()
+
+    logging.info("All caches cleared")
+
+
+def get_cache_info() -> Dict[str, Dict]:
+    """
+    Get information about all caches.
+
+    Returns:
+        Dictionary with cache statistics for each cache.
+    """
+    return {
+        'hypercube': _HYPERCUBE_CACHE.info(),
+        'power2_nodes': _POWER2_NODES_CACHE.info(),
+        'hamming_layers': _HAMMING_LAYERS_CACHE.info()
+    }
+
+
+def set_cache_sizes(
+    hypercube: Optional[int] = None,
+    power2: Optional[int] = None,
+    hamming: Optional[int] = None
+) -> None:
+    """
+    Adjust cache size limits. Clears excess entries if sizes are reduced.
+
+    Args:
+        hypercube: New size limit for hypercube cache
+        power2: New size limit for power2 nodes cache
+        hamming: New size limit for Hamming layers cache
+    """
+    global _HYPERCUBE_CACHE, _POWER2_NODES_CACHE, _HAMMING_LAYERS_CACHE
+
+    if hypercube is not None:
+        old_size = _HYPERCUBE_CACHE.max_size
+        _HYPERCUBE_CACHE.max_size = hypercube
+        if hypercube < old_size:
+            # Evict excess entries
+            while len(_HYPERCUBE_CACHE) > hypercube:
+                _HYPERCUBE_CACHE.cache.popitem(last=False)
+            logging.debug(f"Reduced hypercube cache size to {hypercube}")
+
+    if power2 is not None:
+        old_size = _POWER2_NODES_CACHE.max_size
+        _POWER2_NODES_CACHE.max_size = power2
+        if power2 < old_size:
+            while len(_POWER2_NODES_CACHE) > power2:
+                _POWER2_NODES_CACHE.cache.popitem(last=False)
+            logging.debug(f"Reduced power2 cache size to {power2}")
 
-    return _HAMMING_LAYERS_CACHE[dimension]
+    if hamming is not None:
+        old_size = _HAMMING_LAYERS_CACHE.max_size
+        _HAMMING_LAYERS_CACHE.max_size = hamming
+        if hamming < old_size:
+            while len(_HAMMING_LAYERS_CACHE) > hamming:
+                _HAMMING_LAYERS_CACHE.cache.popitem(last=False)
+            logging.debug(f"Reduced hamming cache size to {hamming}")
 
 
 def optimized_uniform_spanning_tree(G: nx.Graph, dimension: int) -> nx.Graph:
```
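The updated tests themselves live in a file not shown in this excerpt, so the following pytest-style sketch is only an illustration of the kind of coverage the commit message describes; the test name is hypothetical, and it relies solely on functions defined in the diff above:

```python
from hadamard_random_forest.random_forest import (
    clear_caches,
    get_cache_info,
    get_cached_hypercube,
)

def test_cached_hypercube_generation():
    clear_caches()                 # start from empty caches and zeroed stats
    g1 = get_cached_hypercube(3)   # miss: graph is built and cached
    g2 = get_cached_hypercube(3)   # hit: the cached object is returned
    assert g1 is g2
    stats = get_cache_info()["hypercube"]
    assert stats["misses"] == 1 and stats["hits"] == 1
```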
```diff
@@ -628,6 +768,7 @@ def generate_random_forest(
 
         # Optional: save first 5 tree visualizations
         if save_tree and m < 5:
+            # Initialize current_fig before try block to ensure cleanup in all cases
            current_fig = None
            try:
                G = nx.hypercube_graph(num_qubits)
```
```diff
@@ -655,10 +796,16 @@
                 if show_tree and m == 0:
                     # this will pop up the first tree in-line (or in a window)
                     plt.show()
+            except Exception as e:
+                # Log the error for debugging but don't stop the process
+                logging.warning(f"Failed to visualize tree {m}: {e}")
             finally:
-                # Always close the figure to prevent memory leaks, but only if it was created
+                # Always close the figure to prevent memory leaks
+                # This also closes any figures created by plt functions even if current_fig wasn't assigned
                 if current_fig is not None:
                     plt.close(current_fig)
+                # Extra safety: close all figures to prevent any potential leaks
+                plt.close('all')
 
         # Store signs for this tree in pre-allocated array
         signs_stack[m] = signs
```
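The hardened save/show path above follows a general matplotlib cleanup idiom: bind the figure handle before `try`, log rather than re-raise on failure, and close figures in `finally`. A minimal sketch of the same pattern in isolation (the function and its argument are illustrative, not part of this commit):

```python
import logging
import matplotlib.pyplot as plt

def plot_safely(values):
    fig = None  # bound before try so finally can always inspect it
    try:
        fig = plt.figure()
        plt.plot(values)
        plt.show()
    except Exception as e:
        logging.warning(f"Plot failed: {e}")  # log, but keep the caller running
    finally:
        if fig is not None:
            plt.close(fig)   # close the figure we know about
        plt.close('all')     # extra safety: release any stray figures
```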
