1- use rayon:: iter:: IndexedParallelIterator ;
21use crate :: utils:: determine_chunk_size;
32use crate :: { SMat , SvdFloat } ;
43use nalgebra_sparse:: CsrMatrix ;
54use num_traits:: Float ;
5+ use rayon:: iter:: IndexedParallelIterator ;
66use rayon:: iter:: ParallelIterator ;
77use rayon:: prelude:: { IntoParallelIterator , ParallelBridge , ParallelSliceMut } ;
88use std:: ops:: AddAssign ;
@@ -62,7 +62,7 @@ impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
6262 }
6363}
6464
65- impl < ' a , T : Float + AddAssign + Sync + Send > SMat < T > for MaskedCSRMatrix < ' a , T > {
65+ impl < ' a , T : Float + AddAssign + Sync + Send + std :: ops :: MulAssign > SMat < T > for MaskedCSRMatrix < ' a , T > {
/// Returns the number of rows, delegating to the wrapped `matrix`.
6666 fn nrows ( & self ) -> usize {
6767 self . matrix . nrows ( )
6868 }
@@ -121,29 +121,29 @@ impl<'a, T: Float + AddAssign + Sync + Send> SMat<T> for MaskedCSRMatrix<'a, T>
121121 // A * x calculation
122122 let row_count = self . matrix . nrows ( ) ;
123123 let ( major_offsets, minor_indices, values) = self . matrix . csr_data ( ) ;
124-
124+
125125 let chunk_size = std:: cmp:: max ( 16 , row_count / ( rayon:: current_num_threads ( ) * 2 ) ) ;
126-
126+
127127 let mut valid_indices = Vec :: with_capacity ( self . matrix . ncols ( ) ) ;
128128 for col in 0 ..self . matrix . ncols ( ) {
129129 valid_indices. push ( self . original_to_masked [ col] ) ;
130130 }
131-
131+
132132 y. par_chunks_mut ( chunk_size)
133133 . enumerate ( )
134134 . for_each ( |( chunk_idx, y_chunk) | {
135135 let start_row = chunk_idx * chunk_size;
136136 let end_row = ( start_row + y_chunk. len ( ) ) . min ( row_count) ;
137-
137+
138138 for i in start_row..end_row {
139139 let row_idx = i - start_row;
140140 let mut sum = T :: zero ( ) ;
141-
141+
142142 let row_start = major_offsets[ i] ;
143143 let row_end = major_offsets[ i + 1 ] ;
144-
144+
145145 let mut j = row_start;
146-
146+
147147 while j + 4 <= row_end {
148148 for offset in 0 ..4 {
149149 let idx = j + offset;
@@ -154,7 +154,7 @@ impl<'a, T: Float + AddAssign + Sync + Send> SMat<T> for MaskedCSRMatrix<'a, T>
154154 }
155155 j += 4 ;
156156 }
157-
157+
158158 while j < row_end {
159159 let col = minor_indices[ j] ;
160160 if let Some ( masked_col) = valid_indices[ col] {
@@ -202,12 +202,37 @@ impl<'a, T: Float + AddAssign + Sync + Send> SMat<T> for MaskedCSRMatrix<'a, T>
202202 }
203203 }
204204 }
205+
206+ fn compute_column_means ( & self ) -> Vec < T > {
207+ let rows = self . nrows ( ) ;
208+ let masked_cols = self . ncols ( ) ;
209+ let row_count_recip = T :: one ( ) / T :: from ( rows) . unwrap ( ) ;
210+
211+ let mut col_sums = vec ! [ T :: zero( ) ; masked_cols] ;
212+ let ( row_offsets, col_indices, values) = self . matrix . csr_data ( ) ;
213+
214+ for i in 0 ..rows {
215+ for j in row_offsets[ i] ..row_offsets[ i + 1 ] {
216+ let original_col = col_indices[ j] ;
217+ if let Some ( masked_col) = self . original_to_masked [ original_col] {
218+ col_sums[ masked_col] += values[ j] ;
219+ }
220+ }
221+ }
222+
223+ // Convert to means
224+ for j in 0 ..masked_cols {
225+ col_sums[ j] *= row_count_recip;
226+ }
227+
228+ col_sums
229+ }
205230}
206231
207232#[ cfg( test) ]
208233mod tests {
209234 use super :: * ;
210- use crate :: { SMat } ;
235+ use crate :: SMat ;
211236 use nalgebra_sparse:: { coo:: CooMatrix , csr:: CsrMatrix } ;
212237 use rand:: rngs:: StdRng ;
213238 use rand:: { Rng , SeedableRng } ;
0 commit comments