Skip to content

Commit 5bd6c24

Browse files
authored
Merge pull request #79 from MattScicluna/remove_csr_coo_csr_conversion
Remove csr coo conversion
2 parents c18d494 + dcd3b83 commit 5bd6c24

File tree

2 files changed

+465
-45
lines changed

2 files changed

+465
-45
lines changed

graphtools/graphs.py

Lines changed: 190 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,191 @@ def _numba_threshold_affinities(affinities, thresh):
374374
return result
375375

376376

377+
@njit(parallel=True)
378+
def _numba_build_csr_from_neighbors_scalar_bw(
379+
row_neighbors, row_distances, bandwidth, decay, thresh, n_rows, n_cols
380+
):
381+
"""
382+
Build CSR in a single pass; threshold before allocation; avoid COO conversions for lower peak memory.
383+
Optimized version for scalar bandwidth using numba acceleration.
384+
385+
Two-pass approach:
386+
- Pass 1: compute weights per row, apply thresh to make a boolean mask, and count kept edges
387+
- Pass 2: fill each row's slice with the kept neighbors and weights
388+
"""
389+
# Use float64 for numerical precision compatibility
390+
bandwidth_f64 = np.float64(bandwidth)
391+
decay_f64 = np.float64(decay)
392+
thresh_f64 = np.float64(thresh)
393+
394+
# Pass 1: count valid entries per row (parallel)
395+
row_kept_counts = np.zeros(n_rows, dtype=np.int64)
396+
397+
for i in prange(n_rows):
398+
count = 0
399+
for j in range(len(row_distances[i])):
400+
if j < len(row_neighbors[i]) and row_neighbors[i][j] >= 0:
401+
# Compute weight
402+
scaled = np.float64(row_distances[i][j]) / bandwidth_f64
403+
powered = scaled**decay_f64
404+
affinity = np.exp(-powered)
405+
406+
# Handle edge cases
407+
if np.isnan(affinity) or np.isinf(affinity):
408+
affinity = np.float64(1.0)
409+
410+
if affinity >= thresh_f64:
411+
count += 1
412+
413+
row_kept_counts[i] = count
414+
415+
# Compute indptr
416+
indptr = np.empty(n_rows + 1, dtype=np.int64)
417+
indptr[0] = 0
418+
for i in range(n_rows):
419+
indptr[i + 1] = indptr[i] + row_kept_counts[i]
420+
421+
nnz = int(indptr[-1])
422+
423+
# Allocate output arrays
424+
indices = np.empty(nnz, dtype=np.int32)
425+
data = np.empty(nnz, dtype=np.float64)
426+
427+
# Pass 2: fill arrays (parallel by row)
428+
for i in prange(n_rows):
429+
start_pos = indptr[i]
430+
write_pos = start_pos
431+
432+
for j in range(len(row_distances[i])):
433+
if j < len(row_neighbors[i]) and row_neighbors[i][j] >= 0:
434+
# Recompute weight (same as counting pass)
435+
scaled = np.float64(row_distances[i][j]) / bandwidth_f64
436+
powered = scaled**decay_f64
437+
affinity = np.exp(-powered)
438+
439+
if np.isnan(affinity) or np.isinf(affinity):
440+
affinity = np.float64(1.0)
441+
442+
if affinity >= thresh_f64:
443+
indices[write_pos] = row_neighbors[i][j]
444+
data[write_pos] = affinity
445+
write_pos += 1
446+
447+
return data, indices, indptr
448+
449+
450+
def _build_csr_from_neighbors(row_neighbors, row_distances, bandwidth, decay, thresh, shape):
451+
"""
452+
Build CSR in a single pass; threshold before allocation; avoid COO conversions for lower peak memory.
453+
454+
Parameters
455+
----------
456+
row_neighbors : list of arrays
457+
Per-row neighbor indices
458+
row_distances : list of arrays
459+
Per-row distances to neighbors
460+
bandwidth : float or array
461+
Bandwidth parameter(s)
462+
decay : float
463+
Decay parameter
464+
thresh : float
465+
Threshold for keeping edges
466+
shape : tuple
467+
Shape of output matrix (n_rows, n_cols)
468+
469+
Returns
470+
-------
471+
csr_matrix : scipy.sparse.csr_matrix
472+
Constructed CSR matrix with thresholding applied
473+
"""
474+
n_rows, n_cols = shape
475+
476+
# Handle scalar bandwidth with numba optimization
477+
if isinstance(bandwidth, numbers.Number) and NUMBA_AVAILABLE:
478+
# Convert lists to arrays for numba
479+
max_neighbors = max(len(neighbors) for neighbors in row_neighbors)
480+
neighbors_array = np.full((n_rows, max_neighbors), -1, dtype=np.int32)
481+
distances_array = np.full((n_rows, max_neighbors), np.inf, dtype=np.float64)
482+
483+
for i in range(n_rows):
484+
n_neighbors = len(row_neighbors[i])
485+
neighbors_array[i, :n_neighbors] = row_neighbors[i]
486+
distances_array[i, :n_neighbors] = row_distances[i]
487+
488+
data, indices, indptr = _numba_build_csr_from_neighbors_scalar_bw(
489+
neighbors_array, distances_array, bandwidth, decay, thresh, n_rows, n_cols
490+
)
491+
else:
492+
# Fallback implementation for variable bandwidth or no numba
493+
# Pass 1: compute weights and count kept edges
494+
row_masks = []
495+
row_kept_counts = np.empty(n_rows, dtype=np.int64)
496+
497+
for i in range(n_rows):
498+
distances_i = np.array(row_distances[i], dtype=np.float64)
499+
if isinstance(bandwidth, numbers.Number):
500+
bw = bandwidth
501+
else:
502+
bw = bandwidth[i]
503+
504+
# Compute weights
505+
scaled = distances_i / bw
506+
weights = np.exp(-np.power(scaled, decay))
507+
weights = np.where(np.isnan(weights), 1.0, weights)
508+
509+
# Apply threshold
510+
mask = weights >= thresh
511+
row_masks.append(mask)
512+
row_kept_counts[i] = int(np.count_nonzero(mask))
513+
514+
# Compute indptr
515+
indptr = np.empty(n_rows + 1, dtype=np.int64)
516+
indptr[0] = 0
517+
np.cumsum(row_kept_counts, out=indptr[1:])
518+
nnz = int(indptr[-1])
519+
520+
# Allocate output arrays
521+
indices = np.empty(nnz, dtype=np.int32)
522+
data = np.empty(nnz, dtype=np.float64)
523+
524+
# Pass 2: fill arrays
525+
for i in range(n_rows):
526+
start, end = indptr[i], indptr[i + 1]
527+
if start == end:
528+
continue
529+
530+
mask = row_masks[i]
531+
neighbors_i = np.array(row_neighbors[i])[mask]
532+
distances_i = np.array(row_distances[i], dtype=np.float64)[mask]
533+
534+
if isinstance(bandwidth, numbers.Number):
535+
bw = bandwidth
536+
else:
537+
bw = bandwidth[i]
538+
539+
# Recompute weights for kept edges
540+
scaled = distances_i / bw
541+
weights_i = np.exp(-np.power(scaled, decay))
542+
weights_i = np.where(np.isnan(weights_i), 1.0, weights_i)
543+
544+
# Optional: sort by column index for canonical order
545+
if len(neighbors_i) > 1:
546+
sort_order = np.argsort(neighbors_i)
547+
neighbors_i = neighbors_i[sort_order]
548+
weights_i = weights_i[sort_order]
549+
550+
indices[start:end] = neighbors_i
551+
data[start:end] = weights_i
552+
553+
# Build CSR matrix
554+
K = sparse.csr_matrix((data, indices, indptr), shape=shape)
555+
556+
# Handle potential duplicates
557+
K.sum_duplicates()
558+
559+
return K
560+
561+
377562
class kNNGraph(DataGraph):
378563
"""
379564
K nearest neighbors graph
@@ -789,50 +974,11 @@ def build_kernel_to_data(
789974
for i, idx in enumerate(update_idx):
790975
distances[idx] = dist_new[i]
791976
indices[idx] = ind_new[i]
792-
# Scale distances and compute affinities
793-
if isinstance(bandwidth, numbers.Number):
794-
distances_flat = np.concatenate(distances)
795-
if NUMBA_AVAILABLE:
796-
data = _numba_scale_distances_single_bandwidth(
797-
distances_flat, bandwidth
798-
)
799-
data = _numba_compute_affinities(data, self.decay)
800-
data = _numba_threshold_affinities(data, self.thresh)
801-
else:
802-
data = distances_flat / bandwidth
803-
data = np.exp(-1 * np.power(data, self.decay))
804-
data = np.where(np.isnan(data), 1, data)
805-
data[data < self.thresh] = 0
806-
else:
807-
if NUMBA_AVAILABLE:
808-
# For variable bandwidth, we need to handle the scaling differently
809-
data = []
810-
for i in range(len(distances)):
811-
scaled = _numba_scale_distances_single_bandwidth(
812-
distances[i], bandwidth[i]
813-
)
814-
affinities = _numba_compute_affinities(scaled, self.decay)
815-
thresholded = _numba_threshold_affinities(
816-
affinities, self.thresh
817-
)
818-
data.append(thresholded)
819-
data = np.concatenate(data)
820-
else:
821-
data = np.concatenate(
822-
[distances[i] / bandwidth[i] for i in range(len(distances))]
823-
)
824-
data = np.exp(-1 * np.power(data, self.decay))
825-
data = np.where(np.isnan(data), 1, data)
826-
data[data < self.thresh] = 0
827-
828-
indices = np.concatenate(indices)
829-
indptr = np.concatenate([[0], np.cumsum([len(d) for d in distances])])
830-
K = sparse.csr_matrix(
831-
(data, indices, indptr), shape=(Y.shape[0], self.data_nu.shape[0])
977+
# Use optimized CSR construction to avoid COO conversions
978+
K = _build_csr_from_neighbors(
979+
indices, distances, bandwidth, self.decay, self.thresh,
980+
(Y.shape[0], self.data_nu.shape[0])
832981
)
833-
K = K.tocoo()
834-
K.eliminate_zeros()
835-
K = K.tocsr()
836982
return K
837983

838984

@@ -1457,9 +1603,8 @@ def build_kernel(self):
14571603
):
14581604
K = K.tocsr()
14591605
K.data[K.data < self.thresh] = 0
1460-
K = K.tocoo()
1606+
# Eliminate zeros directly on CSR - avoid unnecessary COO conversion
14611607
K.eliminate_zeros()
1462-
K = K.tocsr()
14631608
else:
14641609
K[K < self.thresh] = 0
14651610
return K

0 commit comments

Comments
 (0)