@@ -7,6 +7,7 @@ use ndarray::{Array1, Array2, ArrayView2};
77use rand:: seq:: SliceRandom ;
88use rand:: SeedableRng ;
99use rand_chacha:: ChaCha8Rng ;
10+ use rayon:: prelude:: * ;
1011use std:: time:: Instant ;
1112
1213/// Result of the k-means algorithm
@@ -81,53 +82,67 @@ pub fn kmeans_double_chunked(
8182 // Pre-compute centroid norms
8283 let centroid_norms = compute_squared_norms ( & centroids. view ( ) ) ;
8384
84- // Accumulators for new centroids
85- let mut cluster_sums: Array2 < f32 > = Array2 :: zeros ( ( k, n_features) ) ;
86- let mut cluster_counts: Array1 < f32 > = Array1 :: zeros ( k) ;
85+ // Build chunk ranges for parallel processing
86+ let chunk_ranges: Vec < ( usize , usize ) > = ( 0 ..n_samples_used)
87+ . step_by ( config. chunk_size_data )
88+ . map ( |start| ( start, ( start + config. chunk_size_data ) . min ( n_samples_used) ) )
89+ . collect ( ) ;
90+
91+ // Process all chunks in parallel
92+ #[ allow( clippy:: type_complexity) ]
93+ let chunk_results: Vec < ( Array2 < f32 > , Array1 < f32 > , Vec < ( usize , i64 ) > ) > = chunk_ranges
94+ . par_iter ( )
95+ . map ( |& ( start_idx, end_idx) | {
96+ let data_chunk = data_subset. slice ( ndarray:: s![ start_idx..end_idx, ..] ) ;
97+ let data_chunk_norms = data_norms. slice ( ndarray:: s![ start_idx..end_idx] ) ;
98+
99+ let chunk_labels = find_nearest_centroids_chunked (
100+ & data_chunk,
101+ & data_chunk_norms,
102+ & centroids. view ( ) ,
103+ & centroid_norms. view ( ) ,
104+ config. chunk_size_centroids ,
105+ ) ;
87106
88- // Process data in chunks
89- let mut start_idx = 0 ;
90- while start_idx < n_samples_used {
91- let end_idx = ( start_idx + config. chunk_size_data ) . min ( n_samples_used) ;
92- let data_chunk = data_subset. slice ( ndarray:: s![ start_idx..end_idx, ..] ) ;
93- let data_chunk_norms = data_norms. slice ( ndarray:: s![ start_idx..end_idx] ) ;
94-
95- // Find nearest centroids for this chunk
96- let chunk_labels = find_nearest_centroids_chunked (
97- & data_chunk,
98- & data_chunk_norms,
99- & centroids. view ( ) ,
100- & centroid_norms. view ( ) ,
101- config. chunk_size_centroids ,
102- ) ;
107+ let mut local_sums: Array2 < f32 > = Array2 :: zeros ( ( k, n_features) ) ;
108+ let mut local_counts: Array1 < f32 > = Array1 :: zeros ( k) ;
109+ let mut label_pairs: Vec < ( usize , i64 ) > = Vec :: with_capacity ( end_idx - start_idx) ;
110+
111+ for ( i, & label) in chunk_labels. iter ( ) . enumerate ( ) {
112+ let cluster_idx = label as usize ;
113+ local_counts[ cluster_idx] += 1.0 ;
114+ local_sums
115+ . row_mut ( cluster_idx)
116+ . scaled_add ( 1.0 , & data_chunk. row ( i) ) ;
117+ label_pairs. push ( ( start_idx + i, label) ) ;
118+ }
103119
104- // Update labels
105- for ( i, & label) in chunk_labels. iter ( ) . enumerate ( ) {
106- labels[ start_idx + i] = label;
107- }
120+ ( local_sums, local_counts, label_pairs)
121+ } )
122+ . collect ( ) ;
108123
109- // Accumulate cluster sums and counts
110- for ( i, & label) in chunk_labels. iter ( ) . enumerate ( ) {
111- let cluster_idx = label as usize ;
112- cluster_counts[ cluster_idx] += 1.0 ;
113- for j in 0 ..n_features {
114- cluster_sums[ [ cluster_idx, j] ] += data_chunk[ [ i, j] ] ;
115- }
116- }
124+ // Reduce: merge all chunk results
125+ let mut cluster_sums: Array2 < f32 > = Array2 :: zeros ( ( k, n_features) ) ;
126+ let mut cluster_counts: Array1 < f32 > = Array1 :: zeros ( k) ;
117127
118- start_idx = end_idx;
128+ for ( local_sums, local_counts, label_pairs) in chunk_results {
129+ cluster_sums += & local_sums;
130+ cluster_counts += & local_counts;
131+ for ( idx, label) in label_pairs {
132+ labels[ idx] = label;
133+ }
119134 }
120135
121- // Compute new centroids
136+ // Compute new centroids using vectorized operations
122137 let prev_centroids = centroids. clone ( ) ;
123138 let mut empty_clusters = Vec :: new ( ) ;
124139
125140 for cluster_idx in 0 ..k {
126141 let count = cluster_counts[ cluster_idx] ;
127142 if count > 0.0 {
128- for j in 0 ..n_features {
129- centroids [ [ cluster_idx , j ] ] = cluster_sums[ [ cluster_idx, j ] ] / count ;
130- }
143+ let mut centroid_row = centroids . row_mut ( cluster_idx ) ;
144+ let sum_row = cluster_sums. row ( cluster_idx) ;
145+ centroid_row . assign ( & ( & sum_row / count ) ) ;
131146 } else {
132147 empty_clusters. push ( cluster_idx) ;
133148 }
@@ -143,9 +158,9 @@ pub fn kmeans_double_chunked(
143158
144159 for ( i, & cluster_idx) in empty_clusters. iter ( ) . enumerate ( ) {
145160 let data_idx = random_indices[ i] ;
146- for j in 0 ..n_features {
147- centroids [ [ cluster_idx, j ] ] = data_subset [ [ data_idx , j ] ] ;
148- }
161+ centroids
162+ . row_mut ( cluster_idx)
163+ . assign ( & data_subset . row ( data_idx ) ) ;
149164 }
150165
151166 if config. verbose {
@@ -214,9 +229,7 @@ fn subsample_data(
214229 let n_features = data. ncols ( ) ;
215230 let mut subset = Array2 :: zeros ( ( max_samples, n_features) ) ;
216231 for ( new_idx, & old_idx) in indices. iter ( ) . enumerate ( ) {
217- for j in 0 ..n_features {
218- subset[ [ new_idx, j] ] = data[ [ old_idx, j] ] ;
219- }
232+ subset. row_mut ( new_idx) . assign ( & data. row ( old_idx) ) ;
220233 }
221234
222235 return Ok ( ( subset, Some ( indices) ) ) ;
@@ -237,9 +250,7 @@ fn initialize_centroids(data: &ArrayView2<f32>, k: usize, rng: &mut ChaCha8Rng)
237250
238251 let mut centroids = Array2 :: zeros ( ( k, n_features) ) ;
239252 for ( centroid_idx, & data_idx) in selected. iter ( ) . enumerate ( ) {
240- for j in 0 ..n_features {
241- centroids[ [ centroid_idx, j] ] = data[ [ data_idx, j] ] ;
242- }
253+ centroids. row_mut ( centroid_idx) . assign ( & data. row ( data_idx) ) ;
243254 }
244255
245256 centroids
0 commit comments