Skip to content

Commit e0a10f4

Browse files
raphaelsty and claude
committed
Bump version to 0.1.5 with performance optimizations
- Add parallel chunk processing for large datasets (~2x speedup)
- Use vectorized row operations for cluster sum accumulation
- Optimize centroid updates with ndarray operations
- Minor improvement in distance computation using row views

Benchmarks (100k samples, 1k clusters):
- Without BLAS: 1.65s -> 807ms (2x faster)
- With accelerate: 450ms -> 307ms (1.5x faster)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent fc29b46 commit e0a10f4

File tree

4 files changed

+60
-48
lines changed

4 files changed

+60
-48
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "fastkmeans-rs"
3-
version = "0.1.4"
3+
version = "0.1.5"
44
edition = "2021"
55
description = "A fast and efficient k-means clustering implementation in Rust, compatible with ndarray"
66
license = "Apache-2.0"

src/algorithm.rs

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use ndarray::{Array1, Array2, ArrayView2};
77
use rand::seq::SliceRandom;
88
use rand::SeedableRng;
99
use rand_chacha::ChaCha8Rng;
10+
use rayon::prelude::*;
1011
use std::time::Instant;
1112

1213
/// Result of the k-means algorithm
@@ -81,53 +82,67 @@ pub fn kmeans_double_chunked(
8182
// Pre-compute centroid norms
8283
let centroid_norms = compute_squared_norms(&centroids.view());
8384

84-
// Accumulators for new centroids
85-
let mut cluster_sums: Array2<f32> = Array2::zeros((k, n_features));
86-
let mut cluster_counts: Array1<f32> = Array1::zeros(k);
85+
// Build chunk ranges for parallel processing
86+
let chunk_ranges: Vec<(usize, usize)> = (0..n_samples_used)
87+
.step_by(config.chunk_size_data)
88+
.map(|start| (start, (start + config.chunk_size_data).min(n_samples_used)))
89+
.collect();
90+
91+
// Process all chunks in parallel
92+
#[allow(clippy::type_complexity)]
93+
let chunk_results: Vec<(Array2<f32>, Array1<f32>, Vec<(usize, i64)>)> = chunk_ranges
94+
.par_iter()
95+
.map(|&(start_idx, end_idx)| {
96+
let data_chunk = data_subset.slice(ndarray::s![start_idx..end_idx, ..]);
97+
let data_chunk_norms = data_norms.slice(ndarray::s![start_idx..end_idx]);
98+
99+
let chunk_labels = find_nearest_centroids_chunked(
100+
&data_chunk,
101+
&data_chunk_norms,
102+
&centroids.view(),
103+
&centroid_norms.view(),
104+
config.chunk_size_centroids,
105+
);
87106

88-
// Process data in chunks
89-
let mut start_idx = 0;
90-
while start_idx < n_samples_used {
91-
let end_idx = (start_idx + config.chunk_size_data).min(n_samples_used);
92-
let data_chunk = data_subset.slice(ndarray::s![start_idx..end_idx, ..]);
93-
let data_chunk_norms = data_norms.slice(ndarray::s![start_idx..end_idx]);
94-
95-
// Find nearest centroids for this chunk
96-
let chunk_labels = find_nearest_centroids_chunked(
97-
&data_chunk,
98-
&data_chunk_norms,
99-
&centroids.view(),
100-
&centroid_norms.view(),
101-
config.chunk_size_centroids,
102-
);
107+
let mut local_sums: Array2<f32> = Array2::zeros((k, n_features));
108+
let mut local_counts: Array1<f32> = Array1::zeros(k);
109+
let mut label_pairs: Vec<(usize, i64)> = Vec::with_capacity(end_idx - start_idx);
110+
111+
for (i, &label) in chunk_labels.iter().enumerate() {
112+
let cluster_idx = label as usize;
113+
local_counts[cluster_idx] += 1.0;
114+
local_sums
115+
.row_mut(cluster_idx)
116+
.scaled_add(1.0, &data_chunk.row(i));
117+
label_pairs.push((start_idx + i, label));
118+
}
103119

104-
// Update labels
105-
for (i, &label) in chunk_labels.iter().enumerate() {
106-
labels[start_idx + i] = label;
107-
}
120+
(local_sums, local_counts, label_pairs)
121+
})
122+
.collect();
108123

109-
// Accumulate cluster sums and counts
110-
for (i, &label) in chunk_labels.iter().enumerate() {
111-
let cluster_idx = label as usize;
112-
cluster_counts[cluster_idx] += 1.0;
113-
for j in 0..n_features {
114-
cluster_sums[[cluster_idx, j]] += data_chunk[[i, j]];
115-
}
116-
}
124+
// Reduce: merge all chunk results
125+
let mut cluster_sums: Array2<f32> = Array2::zeros((k, n_features));
126+
let mut cluster_counts: Array1<f32> = Array1::zeros(k);
117127

118-
start_idx = end_idx;
128+
for (local_sums, local_counts, label_pairs) in chunk_results {
129+
cluster_sums += &local_sums;
130+
cluster_counts += &local_counts;
131+
for (idx, label) in label_pairs {
132+
labels[idx] = label;
133+
}
119134
}
120135

121-
// Compute new centroids
136+
// Compute new centroids using vectorized operations
122137
let prev_centroids = centroids.clone();
123138
let mut empty_clusters = Vec::new();
124139

125140
for cluster_idx in 0..k {
126141
let count = cluster_counts[cluster_idx];
127142
if count > 0.0 {
128-
for j in 0..n_features {
129-
centroids[[cluster_idx, j]] = cluster_sums[[cluster_idx, j]] / count;
130-
}
143+
let mut centroid_row = centroids.row_mut(cluster_idx);
144+
let sum_row = cluster_sums.row(cluster_idx);
145+
centroid_row.assign(&(&sum_row / count));
131146
} else {
132147
empty_clusters.push(cluster_idx);
133148
}
@@ -143,9 +158,9 @@ pub fn kmeans_double_chunked(
143158

144159
for (i, &cluster_idx) in empty_clusters.iter().enumerate() {
145160
let data_idx = random_indices[i];
146-
for j in 0..n_features {
147-
centroids[[cluster_idx, j]] = data_subset[[data_idx, j]];
148-
}
161+
centroids
162+
.row_mut(cluster_idx)
163+
.assign(&data_subset.row(data_idx));
149164
}
150165

151166
if config.verbose {
@@ -214,9 +229,7 @@ fn subsample_data(
214229
let n_features = data.ncols();
215230
let mut subset = Array2::zeros((max_samples, n_features));
216231
for (new_idx, &old_idx) in indices.iter().enumerate() {
217-
for j in 0..n_features {
218-
subset[[new_idx, j]] = data[[old_idx, j]];
219-
}
232+
subset.row_mut(new_idx).assign(&data.row(old_idx));
220233
}
221234

222235
return Ok((subset, Some(indices)));
@@ -237,9 +250,7 @@ fn initialize_centroids(data: &ArrayView2<f32>, k: usize, rng: &mut ChaCha8Rng)
237250

238251
let mut centroids = Array2::zeros((k, n_features));
239252
for (centroid_idx, &data_idx) in selected.iter().enumerate() {
240-
for j in 0..n_features {
241-
centroids[[centroid_idx, j]] = data[[data_idx, j]];
242-
}
253+
centroids.row_mut(centroid_idx).assign(&data.row(data_idx));
243254
}
244255

245256
centroids

src/distance.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ pub fn find_nearest_centroids_chunked(
7474
// dist_chunk has shape (n_data, chunk_centroids)
7575
let n_centroids_chunk = c_end - c_start;
7676

77-
// Compute x.c using matrix multiplication
77+
// Compute x.c using matrix multiplication (BLAS accelerated)
7878
// data_chunk: (n_data, n_features), centroid_chunk.t(): (n_features, n_centroids_chunk)
7979
// Result: (n_data, n_centroids_chunk)
8080
let dot_products = data_chunk.dot(&centroid_chunk.t());
@@ -88,10 +88,11 @@ pub fn find_nearest_centroids_chunked(
8888
.enumerate()
8989
.for_each(|(i, (label, best_dist))| {
9090
let x_norm = data_norms[i];
91+
let dot_row = dot_products.row(i);
9192

9293
for j in 0..n_centroids_chunk {
9394
let c_norm = centroid_chunk_norms[j];
94-
let dot = dot_products[[i, j]];
95+
let dot = dot_row[j];
9596

9697
// Squared distance: ||x||^2 + ||c||^2 - 2*x.c
9798
let dist = x_norm + c_norm - 2.0 * dot;

0 commit comments

Comments (0)