1010import warnings
1111from pathlib import Path
1212from typing import List , Optional , Tuple
13+ import multiprocessing as mp
1314
1415import numpy as np
1516import networkx as nx
@@ -269,9 +270,27 @@ def get_signs(
269270 signs [i ] = np .prod (weights [path ], axis = 0 )
270271 return signs
271272
273+ # Validate idx_path_matrix has sufficient elements for slicing
274+ if len (idx_path_matrix ) < 2 :
275+ # Fall back to dense method for edge cases
276+ signs = np .zeros_like (weights )
277+ for i , path in enumerate (path_to_node ):
278+ signs [i ] = np .prod (weights [path ], axis = 0 )
279+ return signs
280+
272281 # Using sparse matrix reduction
273282 data = path_matrix .multiply (weights ).data
274- return np .multiply .reduceat (data , idx_path_matrix [:- 1 ])
283+ indices = idx_path_matrix [:- 1 ]
284+
285+ # Additional safety check for empty data or indices
286+ if len (data ) == 0 or len (indices ) == 0 :
287+ # Fall back to dense method
288+ signs = np .zeros_like (weights )
289+ for i , path in enumerate (path_to_node ):
290+ signs [i ] = np .prod (weights [path ], axis = 0 )
291+ return signs
292+
293+ return np .multiply .reduceat (data , indices )
275294
276295
277296def majority_voting (votes : np .ndarray ) -> np .ndarray :
@@ -296,6 +315,42 @@ def majority_voting(votes: np.ndarray) -> np.ndarray:
296315 return result
297316
298317
def _generate_single_tree_worker(args: Tuple) -> Tuple[int, np.ndarray]:
    """
    Build one random spanning tree and compute its sign array.

    The inputs arrive packed in a single tuple so the function can be mapped
    directly over a list of argument tuples (for example by a process pool).

    Args:
        args: Tuple of (num_qubits, samples, tree_index, base_seed).

    Returns:
        Tuple of (tree_index, signs), where signs is the array returned by
        get_signs for this tree. The index is echoed back so the caller can
        place the result in the matching slot.
    """
    num_qubits, samples, tree_index, base_seed = args

    # Each tree gets its own seed, spaced 1000 apart from its neighbours,
    # so the result depends only on (base_seed, tree_index).
    fix_random_seed(base_seed + tree_index * 1000)

    # Random spanning tree; the second return value is not needed here.
    tree, _spanning = generate_hypercube_tree(num_qubits)

    # Global roots and leaves of the tree.
    roots, leafs = find_global_roots_and_leafs(tree, num_qubits)

    # Path representation plus its sparse-matrix form; the row offsets are
    # the cumulative per-row non-zero counts with a leading 0.
    paths = get_path(tree, num_qubits)
    path_matrix = get_path_sparse_matrix(paths, num_qubits)
    nnz_per_row = path_matrix.getnnz(axis=1)
    row_offsets = np.insert(np.cumsum(nnz_per_row), 0, 0)

    # Weights from the samples, then the signs.
    weights = get_weight(samples, roots, leafs, num_qubits)
    return tree_index, get_signs(weights, path_matrix, paths, row_offsets)
353+
299354
300355def generate_random_forest (
301356 num_qubits : int ,
@@ -330,70 +385,105 @@ def generate_random_forest(
330385 show_first = False
331386
332387
333- signs_stack : Optional [np .ndarray ] = None
334-
335- # Prepare output directory if needed
336- if save_tree :
337- base_dir = Path ("forest gallery" ) / f"{ num_qubits } -qubit"
338- base_dir .mkdir (parents = True , exist_ok = True )
339-
340- for m in range (num_trees ):
341- # Step 1: generate random spanning tree
342- tree , spanning = generate_hypercube_tree (num_qubits )
343-
344- # Step 2: find global roots and leaves
345- roots , leafs = find_global_roots_and_leafs (tree , num_qubits )
346-
347- # Step 3: convert to matrix form for parallel sign computation
348- paths = get_path (tree , num_qubits )
349- pmatrix = get_path_sparse_matrix (paths , num_qubits )
350- idx_cumsum = np .insert (np .cumsum (pmatrix .getnnz (axis = 1 )), 0 , 0 )
351-
352- # Step 4: compute weights and signs
353- weights = get_weight (samples , roots , leafs , num_qubits )
354- signs = get_signs (weights , pmatrix , paths , idx_cumsum )
355-
356- # Optional: save first 5 tree visualizations
357- if save_tree and m < 5 :
358- try :
359- G = nx .hypercube_graph (num_qubits )
360- G = nx .convert_node_labels_to_integers (G )
361- pos = nx .drawing .nx_agraph .graphviz_layout (G , prog = "dot" )
362-
363- # Dynamically size the figure
364- base_size = 6
365- extra = max (0 , num_qubits - 5 )
366- width_factor = 2 ** extra
367- height_factor = 1.5 ** extra
368- plt .figure (figsize = (base_size * width_factor , base_size * height_factor ))
369-
370- nx .draw_networkx_edges (G , pos , edge_color = 'tab:gray' , alpha = 0.2 , width = 2 )
371- nx .draw_networkx_edges (spanning , pos , edge_color = 'tab:gray' , width = 3 )
372- node_colors = ['tab:blue' if s == 1 else 'tab:orange' for s in signs ]
373- nx .draw_networkx_nodes (G , pos , node_color = node_colors , node_size = 400 , edgecolors = 'black' )
374- nx .draw_networkx_labels (G , pos , font_color = "white" )
375-
376- plt .axis ('off' )
377- plt .tight_layout ()
378- fig_path = base_dir / f"tree_{ m } .png"
379- plt .savefig (fig_path , bbox_inches = 'tight' , pad_inches = 0 , transparent = True , dpi = 200 )
380-
381- if show_tree and m == 0 :
382- # this will pop up the first tree in-line (or in a window)
383- plt .show ()
384- finally :
385- # Always close the figure to prevent memory leaks
386- plt .close ()
387-
388- # Accumulate for majority voting
389- if signs_stack is None :
390- signs_stack = signs
391- else :
392- signs_stack = np .vstack ([signs_stack , signs ])
393-
394- assert signs_stack is not None
395- # Ensure signs_stack is 2D for majority_voting
396- if signs_stack .ndim == 1 :
397- signs_stack = signs_stack .reshape (1 , - 1 )
388+ # Pre-allocate signs array for all trees to avoid expensive np.vstack operations
389+ N = 2 ** num_qubits
390+ signs_stack = np .zeros ((num_trees , N ), dtype = float )
391+
392+ # Determine if we should use parallel processing
393+ # Use parallel processing for larger num_trees, but avoid when visualization is needed
394+ # or when multiprocessing overhead would exceed benefits
395+ USE_PARALLEL_THRESHOLD = 4
396+ use_parallel = (
397+ num_trees >= USE_PARALLEL_THRESHOLD and
398+ not (save_tree or show_tree ) and # Visualization complicates multiprocessing
399+ mp .cpu_count () > 1 # Only if multiple cores available
400+ )
401+
402+ # Generate base seed for reproducible results across workers
403+ base_seed = random .randint (0 , 2 ** 31 - 1 )
404+
405+ if use_parallel :
406+ # Parallel processing path
407+ logging .info (f"Using parallel processing with { mp .cpu_count ()} cores for { num_trees } trees" )
408+
409+ # Prepare arguments for worker processes
410+ worker_args = [
411+ (num_qubits , samples , tree_index , base_seed )
412+ for tree_index in range (num_trees )
413+ ]
414+
415+ # Use multiprocessing to generate trees in parallel
416+ with mp .Pool () as pool :
417+ results = pool .map (_generate_single_tree_worker , worker_args )
418+
419+ # Collect results in correct order
420+ for tree_index , signs in results :
421+ signs_stack [tree_index ] = signs
422+
423+ else :
424+ # Sequential processing path (original implementation)
425+ # Prepare output directory if needed
426+ if save_tree :
427+ base_dir = Path ("forest gallery" ) / f"{ num_qubits } -qubit"
428+ base_dir .mkdir (parents = True , exist_ok = True )
429+
430+ for m in range (num_trees ):
431+ # Set deterministic seed for this tree
432+ tree_seed = base_seed + m * 1000
433+ fix_random_seed (tree_seed )
434+
435+ # Step 1: generate random spanning tree
436+ tree , spanning = generate_hypercube_tree (num_qubits )
437+
438+ # Step 2: find global roots and leaves
439+ roots , leafs = find_global_roots_and_leafs (tree , num_qubits )
440+
441+ # Step 3: convert to matrix form for parallel sign computation
442+ paths = get_path (tree , num_qubits )
443+ pmatrix = get_path_sparse_matrix (paths , num_qubits )
444+ idx_cumsum = np .insert (np .cumsum (pmatrix .getnnz (axis = 1 )), 0 , 0 )
445+
446+ # Step 4: compute weights and signs
447+ weights = get_weight (samples , roots , leafs , num_qubits )
448+ signs = get_signs (weights , pmatrix , paths , idx_cumsum )
449+
450+ # Optional: save first 5 tree visualizations
451+ if save_tree and m < 5 :
452+ current_fig = None
453+ try :
454+ G = nx .hypercube_graph (num_qubits )
455+ G = nx .convert_node_labels_to_integers (G )
456+ pos = nx .drawing .nx_agraph .graphviz_layout (G , prog = "dot" )
457+
458+ # Dynamically size the figure
459+ base_size = 6
460+ extra = max (0 , num_qubits - 5 )
461+ width_factor = 2 ** extra
462+ height_factor = 1.5 ** extra
463+ current_fig = plt .figure (figsize = (base_size * width_factor , base_size * height_factor ))
464+
465+ nx .draw_networkx_edges (G , pos , edge_color = 'tab:gray' , alpha = 0.2 , width = 2 )
466+ nx .draw_networkx_edges (spanning , pos , edge_color = 'tab:gray' , width = 3 )
467+ node_colors = ['tab:blue' if s == 1 else 'tab:orange' for s in signs ]
468+ nx .draw_networkx_nodes (G , pos , node_color = node_colors , node_size = 400 , edgecolors = 'black' )
469+ nx .draw_networkx_labels (G , pos , font_color = "white" )
470+
471+ plt .axis ('off' )
472+ plt .tight_layout ()
473+ fig_path = base_dir / f"tree_{ m } .png"
474+ plt .savefig (fig_path , bbox_inches = 'tight' , pad_inches = 0 , transparent = True , dpi = 200 )
475+
476+ if show_tree and m == 0 :
477+ # this will pop up the first tree in-line (or in a window)
478+ plt .show ()
479+ finally :
480+ # Always close the figure to prevent memory leaks, but only if it was created
481+ if current_fig is not None :
482+ plt .close (current_fig )
483+
484+ # Store signs for this tree in pre-allocated array
485+ signs_stack [m ] = signs
486+
487+ # signs_stack is already 2D with shape (num_trees, 2**num_qubits)
398488 return majority_voting (signs_stack )
399489
0 commit comments