@@ -374,6 +374,191 @@ def _numba_threshold_affinities(affinities, thresh):
374374 return result
375375
376376
@njit(parallel=True)
def _numba_build_csr_from_neighbors_scalar_bw(
    row_neighbors, row_distances, bandwidth, decay, thresh, n_rows, n_cols
):
    """
    Build thresholded CSR arrays from per-row neighbors (scalar bandwidth).

    The affinity of each (row, neighbor) pair is
    ``exp(-(distance / bandwidth) ** decay)``; a NaN or infinite result is
    replaced by 1.0. Only strictly positive affinities that are
    ``>= thresh`` are stored, so the output never contains explicit zeros.

    Two passes over the rows, both parallel by row:

    - Pass 1 counts the kept entries of each row so the output arrays can
      be allocated at their exact final size.
    - Pass 2 recomputes the affinities and writes the kept entries into
      each row's slice of the output arrays.

    Parameters
    ----------
    row_neighbors : indexable as ``row_neighbors[i][j]``
        Neighbor (column) indices per row. Negative entries are skipped
        (treated as padding).
    row_distances : indexable as ``row_distances[i][j]``
        Distances matching ``row_neighbors`` position by position.
    bandwidth : scalar
        Bandwidth shared by all rows.
    decay : scalar
        Exponent applied to the scaled distance.
    thresh : scalar
        Minimum affinity for an entry to be kept.
    n_rows : int
        Number of rows to process.
    n_cols : int
        Not used by this function; kept so the signature stays stable.

    Returns
    -------
    data : ndarray of float64
        Kept affinities.
    indices : ndarray of int32
        Column index of each kept affinity. Entries within a row keep the
        input order; they are not sorted here.
    indptr : ndarray of int64, length ``n_rows + 1``
        Row pointer array.
    """
    # Use float64 for numerical precision compatibility
    bandwidth_f64 = np.float64(bandwidth)
    decay_f64 = np.float64(decay)
    thresh_f64 = np.float64(thresh)

    # Pass 1: count kept entries per row (parallel)
    row_kept_counts = np.zeros(n_rows, dtype=np.int64)

    for i in prange(n_rows):
        nbrs = row_neighbors[i]
        dists = row_distances[i]
        # Guard against rows whose neighbor/distance lengths differ
        n_entries = min(len(nbrs), len(dists))
        count = 0
        for j in range(n_entries):
            if nbrs[j] >= 0:
                scaled = np.float64(dists[j]) / bandwidth_f64
                affinity = np.exp(-(scaled ** decay_f64))

                # Handle edge cases
                if np.isnan(affinity) or np.isinf(affinity):
                    affinity = np.float64(1.0)

                # An affinity that underflowed to 0.0 would otherwise be
                # stored as an explicit zero whenever thresh <= 0.
                if affinity > 0.0 and affinity >= thresh_f64:
                    count += 1

        row_kept_counts[i] = count

    # Compute indptr as the running sum of the per-row counts
    indptr = np.empty(n_rows + 1, dtype=np.int64)
    indptr[0] = 0
    for i in range(n_rows):
        indptr[i + 1] = indptr[i] + row_kept_counts[i]

    nnz = int(indptr[-1])

    # Allocate output arrays at their exact final size
    indices = np.empty(nnz, dtype=np.int32)
    data = np.empty(nnz, dtype=np.float64)

    # Pass 2: fill arrays (parallel by row; each row owns a disjoint slice)
    for i in prange(n_rows):
        nbrs = row_neighbors[i]
        dists = row_distances[i]
        n_entries = min(len(nbrs), len(dists))
        write_pos = indptr[i]

        for j in range(n_entries):
            if nbrs[j] >= 0:
                # Must stay identical to the computation in the counting pass
                scaled = np.float64(dists[j]) / bandwidth_f64
                affinity = np.exp(-(scaled ** decay_f64))

                if np.isnan(affinity) or np.isinf(affinity):
                    affinity = np.float64(1.0)

                if affinity > 0.0 and affinity >= thresh_f64:
                    indices[write_pos] = nbrs[j]
                    data[write_pos] = affinity
                    write_pos += 1

    return data, indices, indptr
448+
449+
def _build_csr_from_neighbors(row_neighbors, row_distances, bandwidth, decay, thresh, shape):
    """
    Build a thresholded CSR kernel matrix directly from per-row neighbors.

    The weight of each (row, neighbor) pair is
    ``exp(-(distance / bandwidth) ** decay)`` with NaN replaced by 1.0.
    Entries below ``thresh`` are dropped before the output arrays are
    allocated, so no dense or COO intermediate is created.

    Two passes are used: the first counts the kept entries per row, the
    second fills each row's slice of the CSR arrays.

    Parameters
    ----------
    row_neighbors : list of arrays
        Per-row neighbor indices
    row_distances : list of arrays
        Per-row distances to neighbors
    bandwidth : float or array
        Bandwidth parameter(s); an array is indexed by row
    decay : float
        Decay parameter
    thresh : float
        Threshold for keeping edges
    shape : tuple
        Shape of output matrix (n_rows, n_cols)

    Returns
    -------
    csr_matrix : scipy.sparse.csr_matrix
        Constructed CSR matrix with thresholding applied, duplicate
        entries summed, and no explicitly stored zeros
    """
    n_rows, n_cols = shape
    scalar_bandwidth = isinstance(bandwidth, numbers.Number)

    # Handle scalar bandwidth with numba optimization
    if scalar_bandwidth and NUMBA_AVAILABLE:
        # Pad the ragged rows into rectangular arrays for numba; padded
        # slots carry neighbor index -1 and distance inf.
        # default=0 keeps an empty input from raising in max().
        max_neighbors = max(
            (len(neighbors) for neighbors in row_neighbors), default=0
        )
        neighbors_array = np.full((n_rows, max_neighbors), -1, dtype=np.int32)
        distances_array = np.full(
            (n_rows, max_neighbors), np.inf, dtype=np.float64
        )

        for i in range(n_rows):
            n_neighbors = len(row_neighbors[i])
            neighbors_array[i, :n_neighbors] = row_neighbors[i]
            distances_array[i, :n_neighbors] = row_distances[i]

        data, indices, indptr = _numba_build_csr_from_neighbors_scalar_bw(
            neighbors_array, distances_array, bandwidth, decay, thresh, n_rows, n_cols
        )
    else:
        # Fallback implementation for variable bandwidth or no numba

        def _row_weights(i, mask=None):
            # Single definition of the weight formula shared by both passes
            distances_i = np.asarray(row_distances[i], dtype=np.float64)
            if mask is not None:
                distances_i = distances_i[mask]
            bw = bandwidth if scalar_bandwidth else bandwidth[i]
            weights = np.exp(-np.power(distances_i / bw, decay))
            return np.where(np.isnan(weights), 1.0, weights)

        # Pass 1: compute weights and count kept edges
        row_masks = []
        row_kept_counts = np.empty(n_rows, dtype=np.int64)

        for i in range(n_rows):
            mask = _row_weights(i) >= thresh
            row_masks.append(mask)
            row_kept_counts[i] = np.count_nonzero(mask)

        # Compute indptr
        indptr = np.empty(n_rows + 1, dtype=np.int64)
        indptr[0] = 0
        np.cumsum(row_kept_counts, out=indptr[1:])
        nnz = int(indptr[-1])

        # Allocate output arrays
        indices = np.empty(nnz, dtype=np.int32)
        data = np.empty(nnz, dtype=np.float64)

        # Pass 2: fill arrays
        for i in range(n_rows):
            start, end = indptr[i], indptr[i + 1]
            if start == end:
                continue

            mask = row_masks[i]
            neighbors_i = np.asarray(row_neighbors[i])[mask]
            # Recompute weights for kept edges only
            weights_i = _row_weights(i, mask)

            # Sort by column index for canonical order
            if len(neighbors_i) > 1:
                sort_order = np.argsort(neighbors_i)
                neighbors_i = neighbors_i[sort_order]
                weights_i = weights_i[sort_order]

            indices[start:end] = neighbors_i
            data[start:end] = weights_i

    # Build CSR matrix
    K = sparse.csr_matrix((data, indices, indptr), shape=shape)

    # Handle potential duplicates
    K.sum_duplicates()

    # A weight that underflows to 0.0 still passes `>= thresh` when
    # thresh <= 0; drop such explicit zeros so the result stays sparse.
    K.eliminate_zeros()

    return K
560+
561+
377562class kNNGraph (DataGraph ):
378563 """
379564 K nearest neighbors graph
@@ -789,50 +974,11 @@ def build_kernel_to_data(
789974 for i , idx in enumerate (update_idx ):
790975 distances [idx ] = dist_new [i ]
791976 indices [idx ] = ind_new [i ]
792- # Scale distances and compute affinities
793- if isinstance (bandwidth , numbers .Number ):
794- distances_flat = np .concatenate (distances )
795- if NUMBA_AVAILABLE :
796- data = _numba_scale_distances_single_bandwidth (
797- distances_flat , bandwidth
798- )
799- data = _numba_compute_affinities (data , self .decay )
800- data = _numba_threshold_affinities (data , self .thresh )
801- else :
802- data = distances_flat / bandwidth
803- data = np .exp (- 1 * np .power (data , self .decay ))
804- data = np .where (np .isnan (data ), 1 , data )
805- data [data < self .thresh ] = 0
806- else :
807- if NUMBA_AVAILABLE :
808- # For variable bandwidth, we need to handle the scaling differently
809- data = []
810- for i in range (len (distances )):
811- scaled = _numba_scale_distances_single_bandwidth (
812- distances [i ], bandwidth [i ]
813- )
814- affinities = _numba_compute_affinities (scaled , self .decay )
815- thresholded = _numba_threshold_affinities (
816- affinities , self .thresh
817- )
818- data .append (thresholded )
819- data = np .concatenate (data )
820- else :
821- data = np .concatenate (
822- [distances [i ] / bandwidth [i ] for i in range (len (distances ))]
823- )
824- data = np .exp (- 1 * np .power (data , self .decay ))
825- data = np .where (np .isnan (data ), 1 , data )
826- data [data < self .thresh ] = 0
827-
828- indices = np .concatenate (indices )
829- indptr = np .concatenate ([[0 ], np .cumsum ([len (d ) for d in distances ])])
830- K = sparse .csr_matrix (
831- (data , indices , indptr ), shape = (Y .shape [0 ], self .data_nu .shape [0 ])
977+ # Use optimized CSR construction to avoid COO conversions
978+ K = _build_csr_from_neighbors (
979+ indices , distances , bandwidth , self .decay , self .thresh ,
980+ (Y .shape [0 ], self .data_nu .shape [0 ])
832981 )
833- K = K .tocoo ()
834- K .eliminate_zeros ()
835- K = K .tocsr ()
836982 return K
837983
838984
@@ -1457,9 +1603,8 @@ def build_kernel(self):
14571603 ):
14581604 K = K .tocsr ()
14591605 K .data [K .data < self .thresh ] = 0
1460- K = K . tocoo ()
1606+ # Eliminate zeros directly on CSR - avoid unnecessary COO conversion
14611607 K .eliminate_zeros ()
1462- K = K .tocsr ()
14631608 else :
14641609 K [K < self .thresh ] = 0
14651610 return K
0 commit comments