1+ use rayon:: iter:: ParallelIterator ;
2+ use std:: sync:: Arc ;
3+ use ndarray:: { s, Array1 , Array2 , ArrayView2 , Axis } ;
4+ use rayon:: iter:: IntoParallelIterator ;
5+
6+ // Trait for SVD implementations
7+ pub trait SVDImplementation : Send + Sync {
8+ fn compute ( & self , matrix : ArrayView2 < f64 > ) -> ( Array2 < f64 > , Array1 < f64 > , Array2 < f64 > ) ;
9+ }
10+
11+ pub struct PCABuilder < S : SVDImplementation > {
12+ n_components : Option < usize > ,
13+ center : bool ,
14+ scale : bool ,
15+ svd_implementation : Arc < S > ,
16+ }
17+
18+ impl < S : SVDImplementation > PCABuilder < S > {
19+ pub fn new ( svd_implementation : S ) -> Self {
20+ PCABuilder {
21+ n_components : None ,
22+ center : true ,
23+ scale : false ,
24+ svd_implementation : Arc :: new ( svd_implementation) ,
25+ }
26+ }
27+
28+ pub fn n_components ( mut self , n_components : usize ) -> Self {
29+ self . n_components = Some ( n_components) ;
30+ self
31+ }
32+
33+ pub fn center ( mut self , center : bool ) -> Self {
34+ self . center = center;
35+ self
36+ }
37+
38+ pub fn scale ( mut self , scale : bool ) -> Self {
39+ self . scale = scale;
40+ self
41+ }
42+
43+ pub fn build ( self ) -> Pca < S > {
44+ Pca {
45+ n_components : self . n_components ,
46+ center : self . center ,
47+ scale : self . scale ,
48+ svd_implementation : self . svd_implementation ,
49+ components : None ,
50+ mean : None ,
51+ std_dev : None ,
52+ explained_variance_ratio : None ,
53+ total_variance : None ,
54+ eigenvalues : None ,
55+ }
56+ }
57+ }
58+
59+ pub struct Pca < S : SVDImplementation > {
60+ n_components : Option < usize > ,
61+ center : bool ,
62+ scale : bool ,
63+ svd_implementation : Arc < S > ,
64+ components : Option < Array2 < f64 > > ,
65+ mean : Option < Array1 < f64 > > ,
66+ std_dev : Option < Array1 < f64 > > ,
67+ explained_variance_ratio : Option < Array1 < f64 > > ,
68+ total_variance : Option < f64 > ,
69+ eigenvalues : Option < Array1 < f64 > > ,
70+ }
71+
72+ impl < S : SVDImplementation > Pca < S > {
73+ pub fn fit ( & mut self , x : ArrayView2 < f64 > ) -> anyhow:: Result < ( ) > {
74+ let ( n_samples, n_features) = x. dim ( ) ;
75+ let n_components = self . n_components . unwrap_or ( n_features) ;
76+
77+ // Center the data
78+ let mean = if self . center {
79+ Some ( x. mean_axis ( Axis ( 0 ) ) . expect ( "Failed to compute mean" ) )
80+ } else {
81+ None
82+ } ;
83+
84+ // Scale the data
85+ let std_dev = if self . scale {
86+ Some ( x. std_axis ( Axis ( 0 ) , 0.0 ) )
87+ } else {
88+ None
89+ } ;
90+
91+ // Preprocess the data (center and scale)
92+ let x_preprocessed = self . preprocess ( x, & mean, & std_dev) ;
93+
94+ // Compute SVD using the provided implementation
95+ let ( _u, s, vt) = self . svd_implementation . compute ( x_preprocessed. view ( ) ) ;
96+
97+ // Extract principal components and eigenvalues
98+ let components = vt. slice ( s ! [ ..n_components, ..] ) . to_owned ( ) ;
99+
100+ let eigenvalues = s. mapv ( |x| x * x / ( n_samples as f64 - 1.0 ) ) ;
101+
102+ // Compute explained variance ratio
103+ let total_variance = eigenvalues. sum ( ) ;
104+ let explained_variance_ratio = & eigenvalues / total_variance;
105+
106+ // Store results
107+ self . components = Some ( components) ;
108+ self . mean = mean;
109+ self . std_dev = std_dev;
110+ self . explained_variance_ratio = Some (
111+ explained_variance_ratio
112+ . slice ( s ! [ ..n_components] )
113+ . to_owned ( ) ,
114+ ) ;
115+ self . total_variance = Some ( total_variance) ;
116+ self . eigenvalues = Some ( eigenvalues. slice ( s ! [ ..n_components] ) . to_owned ( ) ) ;
117+
118+ Ok ( ( ) )
119+ }
120+
121+ fn preprocess (
122+ & self ,
123+ x : ArrayView2 < f64 > ,
124+ mean : & Option < Array1 < f64 > > ,
125+ std_dev : & Option < Array1 < f64 > > ,
126+ ) -> Array2 < f64 > {
127+ let mut x_preprocessed = x. to_owned ( ) ;
128+
129+ // Center the data
130+ if let Some ( m) = mean {
131+ x_preprocessed
132+ . axis_iter_mut ( Axis ( 0 ) )
133+ . into_par_iter ( )
134+ . for_each ( |mut row| {
135+ row -= m;
136+ } ) ;
137+ }
138+
139+ // Scale the data
140+ if let Some ( s) = std_dev {
141+ x_preprocessed
142+ . axis_iter_mut ( Axis ( 0 ) )
143+ . into_par_iter ( )
144+ . for_each ( |mut row| {
145+ row /= s;
146+ } ) ;
147+ }
148+
149+ x_preprocessed
150+ }
151+
152+ pub fn transform ( & self , x : ArrayView2 < f64 > ) -> anyhow:: Result < Array2 < f64 > > {
153+ if let Some ( components) = & self . components {
154+ let x_preprocessed = self . preprocess ( x, & self . mean , & self . std_dev ) ;
155+
156+ // Ensure that we're using ArrayView2 for the dot product
157+ let x_preprocessed_view = x_preprocessed. view ( ) ;
158+ let components_view = components. view ( ) ;
159+ // Perform the matrix multiplication
160+ Ok ( x_preprocessed_view. dot ( & components_view. t ( ) ) )
161+ } else {
162+ Err ( anyhow:: anyhow!( "PCA has not been fitted yet" ) )
163+ }
164+ }
165+
166+ pub fn fit_transform ( & mut self , x : ArrayView2 < f64 > ) -> anyhow:: Result < Array2 < f64 > > {
167+ self . fit ( x) ?;
168+ self . transform ( x)
169+ }
170+
171+ // Getter methods for the computed values (unchanged)
172+ pub fn components ( & self ) -> Option < & Array2 < f64 > > {
173+ self . components . as_ref ( )
174+ }
175+
176+ pub fn explained_variance_ratio ( & self ) -> Option < & Array1 < f64 > > {
177+ self . explained_variance_ratio . as_ref ( )
178+ }
179+
180+ pub fn total_variance ( & self ) -> Option < f64 > {
181+ self . total_variance
182+ }
183+
184+ pub fn eigenvalues ( & self ) -> Option < & Array1 < f64 > > {
185+ self . eigenvalues . as_ref ( )
186+ }
187+ }
188+
/// SVD backend that delegates to `crate::svd::lapack::SVD`.
///
/// Only available with the `lapack` feature.
#[cfg(feature = "lapack")]
#[derive(Debug, Clone, Copy, Default)]
pub struct LapackSVD;
192+
#[cfg(feature = "lapack")]
impl SVDImplementation for LapackSVD {
    /// Runs the crate's LAPACK-backed SVD on `matrix` and returns clones of
    /// its `u`, `s` and `vt` factors.
    ///
    /// # Panics
    ///
    /// Panics when the decomposition fails or when any factor is missing
    /// afterwards; the trait signature leaves no way to report an error.
    fn compute(&self, matrix: ArrayView2<f64>) -> (Array2<f64>, Array1<f64>, Array2<f64>) {
        let mut svd = crate::svd::lapack::SVD::new();
        svd.compute(matrix)
            .expect("LAPACK SVD computation failed");

        let u = svd
            .u()
            .cloned()
            .expect("LAPACK SVD did not produce the U factor");
        let s = svd
            .s()
            .cloned()
            .expect("LAPACK SVD did not produce singular values");
        let vt = svd
            .vt()
            .cloned()
            .expect("LAPACK SVD did not produce the Vt factor");

        (u, s, vt)
    }
}
207+
/// SVD backend that delegates to `crate::svd::faer::SVD`.
///
/// Only available with the `faer` feature.
#[cfg(feature = "faer")]
#[derive(Debug, Clone, Copy, Default)]
pub struct FaerSVD;
210+
#[cfg(feature = "faer")]
impl SVDImplementation for FaerSVD {
    /// Runs the crate's faer-backed SVD on `matrix` and returns clones of its
    /// `u`, `s` and `vt` factors, in that order.
    fn compute(&self, matrix: ArrayView2<f64>) -> (Array2<f64>, Array1<f64>, Array2<f64>) {
        let decomposition = crate::svd::faer::SVD::new(&matrix);

        let u = decomposition.u().clone();
        let singular_values = decomposition.s().clone();
        let vt = decomposition.vt().clone();

        (u, singular_values, vt)
    }
}
219+
#[cfg(test)]
mod tests {
    // Every test is feature-gated, so gate the shared imports as well to
    // avoid unused-import warnings when neither backend is enabled.
    #[cfg(any(feature = "faer", feature = "lapack"))]
    use super::PCABuilder;
    #[cfg(any(feature = "faer", feature = "lapack"))]
    use ndarray::array;

    #[cfg(feature = "faer")]
    use super::FaerSVD;

    #[cfg(feature = "lapack")]
    use super::LapackSVD;

    /// Fitting with the LAPACK backend populates the components.
    #[cfg(feature = "lapack")]
    #[test]
    fn test_pca_with_lapack_svd() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut pca = PCABuilder::new(LapackSVD).n_components(2).build();

        pca.fit(x.view()).unwrap();

        assert!(pca.components().is_some());
    }

    /// Fitting with the faer backend populates the components.
    #[cfg(feature = "faer")]
    #[test]
    fn test_pca_with_faer_svd() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut pca = PCABuilder::new(FaerSVD).n_components(2).build();

        pca.fit(x.view()).unwrap();

        assert!(pca.components().is_some());
    }

    /// The projection has one column per requested component (LAPACK backend).
    #[cfg(feature = "lapack")]
    #[test]
    fn test_pca_with_different_n_components_lap() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        // Test with n_components = 2
        let mut pca = PCABuilder::new(LapackSVD).n_components(2).build();
        pca.fit(x.view()).unwrap();
        let transformed = pca.transform(x.view()).unwrap();
        assert_eq!(transformed.shape(), &[3, 2]);

        // Test with n_components = 1
        let mut pca_1 = PCABuilder::new(LapackSVD).n_components(1).build();
        pca_1.fit(x.view()).unwrap();
        let transformed_1 = pca_1.transform(x.view()).unwrap();
        assert_eq!(transformed_1.shape(), &[3, 1]);

        // Test with n_components = 3 (full dimensionality)
        let mut pca_3 = PCABuilder::new(LapackSVD).n_components(3).build();
        pca_3.fit(x.view()).unwrap();
        let transformed_3 = pca_3.transform(x.view()).unwrap();
        assert_eq!(transformed_3.shape(), &[3, 3]);
    }

    /// The projection has one column per requested component (faer backend).
    #[cfg(feature = "faer")]
    #[test]
    fn test_pca_with_different_n_components_faer() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        // Test with n_components = 2
        let mut pca = PCABuilder::new(FaerSVD).n_components(2).build();
        pca.fit(x.view()).unwrap();
        let transformed = pca.transform(x.view()).unwrap();
        assert_eq!(transformed.shape(), &[3, 2]);

        // Test with n_components = 1
        let mut pca_1 = PCABuilder::new(FaerSVD).n_components(1).build();
        pca_1.fit(x.view()).unwrap();
        let transformed_1 = pca_1.transform(x.view()).unwrap();
        assert_eq!(transformed_1.shape(), &[3, 1]);

        // Test with n_components = 3 (full dimensionality)
        let mut pca_3 = PCABuilder::new(FaerSVD).n_components(3).build();
        pca_3.fit(x.view()).unwrap();
        let transformed_3 = pca_3.transform(x.view()).unwrap();
        assert_eq!(transformed_3.shape(), &[3, 3]);
    }

    /// Transforming before fitting yields an error, surfaced here via `unwrap`.
    #[test]
    #[should_panic(expected = "PCA has not been fitted yet")]
    #[cfg(feature = "faer")]
    fn test_pca_transform_without_fit() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let pca = PCABuilder::new(FaerSVD).n_components(2).build();

        pca.transform(x.view()).unwrap();
    }
}
0 commit comments