Skip to content

Commit 3382a04

Browse files
authored
Merge pull request #22 from SingleRust/feature-sparse-pca-revamp
Add MaskedSparsePCA and SparsePCA builders
2 parents a60c59d + 1f6f803 commit 3382a04

File tree

6 files changed

+845
-545
lines changed

6 files changed

+845
-545
lines changed

Cargo.lock

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "single_algebra"
3-
version = "0.2.2-alpha.0"
3+
version = "0.2.3-alpha.0"
44
edition = "2021"
55
license-file = "LICENSE.md"
66
description = "A linear algebra convenience library for the single-rust library. Can be used externally as well."
@@ -44,7 +44,7 @@ num-traits = "0.2.19"
4444
rayon = "1.10.0"
4545
simba = { version = "0.9.0", optional = true }
4646
smartcore = { version = "0.4", features = ["ndarray-bindings"], optional = true }
47-
single-svdlib = "0.1.0"
47+
single-svdlib = "0.2.0"
4848
parking_lot = "0.12.3"
4949
petgraph = { version = "0.7.1", features = ["rayon"] }
5050
rand = "0.9.0"

src/dimred/pca/dense/mod.rs

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
use rayon::iter::ParallelIterator;
2+
use std::sync::Arc;
3+
use ndarray::{s, Array1, Array2, ArrayView2, Axis};
4+
use rayon::iter::IntoParallelIterator;
5+
6+
// Trait for SVD implementations
7+
pub trait SVDImplementation: Send + Sync {
8+
fn compute(&self, matrix: ArrayView2<f64>) -> (Array2<f64>, Array1<f64>, Array2<f64>);
9+
}
10+
11+
pub struct PCABuilder<S: SVDImplementation> {
12+
n_components: Option<usize>,
13+
center: bool,
14+
scale: bool,
15+
svd_implementation: Arc<S>,
16+
}
17+
18+
impl<S: SVDImplementation> PCABuilder<S> {
19+
pub fn new(svd_implementation: S) -> Self {
20+
PCABuilder {
21+
n_components: None,
22+
center: true,
23+
scale: false,
24+
svd_implementation: Arc::new(svd_implementation),
25+
}
26+
}
27+
28+
pub fn n_components(mut self, n_components: usize) -> Self {
29+
self.n_components = Some(n_components);
30+
self
31+
}
32+
33+
pub fn center(mut self, center: bool) -> Self {
34+
self.center = center;
35+
self
36+
}
37+
38+
pub fn scale(mut self, scale: bool) -> Self {
39+
self.scale = scale;
40+
self
41+
}
42+
43+
pub fn build(self) -> Pca<S> {
44+
Pca {
45+
n_components: self.n_components,
46+
center: self.center,
47+
scale: self.scale,
48+
svd_implementation: self.svd_implementation,
49+
components: None,
50+
mean: None,
51+
std_dev: None,
52+
explained_variance_ratio: None,
53+
total_variance: None,
54+
eigenvalues: None,
55+
}
56+
}
57+
}
58+
59+
pub struct Pca<S: SVDImplementation> {
60+
n_components: Option<usize>,
61+
center: bool,
62+
scale: bool,
63+
svd_implementation: Arc<S>,
64+
components: Option<Array2<f64>>,
65+
mean: Option<Array1<f64>>,
66+
std_dev: Option<Array1<f64>>,
67+
explained_variance_ratio: Option<Array1<f64>>,
68+
total_variance: Option<f64>,
69+
eigenvalues: Option<Array1<f64>>,
70+
}
71+
72+
impl<S: SVDImplementation> Pca<S> {
73+
pub fn fit(&mut self, x: ArrayView2<f64>) -> anyhow::Result<()> {
74+
let (n_samples, n_features) = x.dim();
75+
let n_components = self.n_components.unwrap_or(n_features);
76+
77+
// Center the data
78+
let mean = if self.center {
79+
Some(x.mean_axis(Axis(0)).expect("Failed to compute mean"))
80+
} else {
81+
None
82+
};
83+
84+
// Scale the data
85+
let std_dev = if self.scale {
86+
Some(x.std_axis(Axis(0), 0.0))
87+
} else {
88+
None
89+
};
90+
91+
// Preprocess the data (center and scale)
92+
let x_preprocessed = self.preprocess(x, &mean, &std_dev);
93+
94+
// Compute SVD using the provided implementation
95+
let (_u, s, vt) = self.svd_implementation.compute(x_preprocessed.view());
96+
97+
// Extract principal components and eigenvalues
98+
let components = vt.slice(s![..n_components, ..]).to_owned();
99+
100+
let eigenvalues = s.mapv(|x| x * x / (n_samples as f64 - 1.0));
101+
102+
// Compute explained variance ratio
103+
let total_variance = eigenvalues.sum();
104+
let explained_variance_ratio = &eigenvalues / total_variance;
105+
106+
// Store results
107+
self.components = Some(components);
108+
self.mean = mean;
109+
self.std_dev = std_dev;
110+
self.explained_variance_ratio = Some(
111+
explained_variance_ratio
112+
.slice(s![..n_components])
113+
.to_owned(),
114+
);
115+
self.total_variance = Some(total_variance);
116+
self.eigenvalues = Some(eigenvalues.slice(s![..n_components]).to_owned());
117+
118+
Ok(())
119+
}
120+
121+
fn preprocess(
122+
&self,
123+
x: ArrayView2<f64>,
124+
mean: &Option<Array1<f64>>,
125+
std_dev: &Option<Array1<f64>>,
126+
) -> Array2<f64> {
127+
let mut x_preprocessed = x.to_owned();
128+
129+
// Center the data
130+
if let Some(m) = mean {
131+
x_preprocessed
132+
.axis_iter_mut(Axis(0))
133+
.into_par_iter()
134+
.for_each(|mut row| {
135+
row -= m;
136+
});
137+
}
138+
139+
// Scale the data
140+
if let Some(s) = std_dev {
141+
x_preprocessed
142+
.axis_iter_mut(Axis(0))
143+
.into_par_iter()
144+
.for_each(|mut row| {
145+
row /= s;
146+
});
147+
}
148+
149+
x_preprocessed
150+
}
151+
152+
pub fn transform(&self, x: ArrayView2<f64>) -> anyhow::Result<Array2<f64>> {
153+
if let Some(components) = &self.components {
154+
let x_preprocessed = self.preprocess(x, &self.mean, &self.std_dev);
155+
156+
// Ensure that we're using ArrayView2 for the dot product
157+
let x_preprocessed_view = x_preprocessed.view();
158+
let components_view = components.view();
159+
// Perform the matrix multiplication
160+
Ok(x_preprocessed_view.dot(&components_view.t()))
161+
} else {
162+
Err(anyhow::anyhow!("PCA has not been fitted yet"))
163+
}
164+
}
165+
166+
pub fn fit_transform(&mut self, x: ArrayView2<f64>) -> anyhow::Result<Array2<f64>> {
167+
self.fit(x)?;
168+
self.transform(x)
169+
}
170+
171+
// Getter methods for the computed values (unchanged)
172+
pub fn components(&self) -> Option<&Array2<f64>> {
173+
self.components.as_ref()
174+
}
175+
176+
pub fn explained_variance_ratio(&self) -> Option<&Array1<f64>> {
177+
self.explained_variance_ratio.as_ref()
178+
}
179+
180+
pub fn total_variance(&self) -> Option<f64> {
181+
self.total_variance
182+
}
183+
184+
pub fn eigenvalues(&self) -> Option<&Array1<f64>> {
185+
self.eigenvalues.as_ref()
186+
}
187+
}
188+
189+
// Example implementation of the SVDImplementation trait
190+
#[cfg(feature = "lapack")]
191+
pub struct LapackSVD;
192+
193+
#[cfg(feature = "lapack")]
194+
impl SVDImplementation for LapackSVD {
195+
fn compute(&self, matrix: ArrayView2<f64>) -> (Array2<f64>, Array1<f64>, Array2<f64>) {
196+
// This is where you'd implement the LAPACK SVD computation
197+
// For now, we'll just return dummy values
198+
let mut svd = crate::svd::lapack::SVD::new();
199+
svd.compute(matrix).unwrap();
200+
(
201+
svd.u().cloned().unwrap(),
202+
svd.s().cloned().unwrap(),
203+
svd.vt().cloned().unwrap(),
204+
)
205+
}
206+
}
207+
208+
#[cfg(feature = "faer")]
209+
pub struct FaerSVD;
210+
211+
#[cfg(feature = "faer")]
212+
impl SVDImplementation for FaerSVD {
213+
fn compute(&self, matrix: ArrayView2<f64>) -> (Array2<f64>, Array1<f64>, Array2<f64>) {
214+
let svd = crate::svd::faer::SVD::new(&matrix);
215+
216+
(svd.u().clone(), svd.s().clone(), svd.vt().clone())
217+
}
218+
}
219+
220+
#[cfg(test)]
221+
mod tests {
222+
use ndarray::array;
223+
use super::PCABuilder;
224+
225+
#[cfg(feature = "faer")]
226+
use super::FaerSVD;
227+
228+
#[cfg(feature = "lapack")]
229+
use super::LapackSVD;
230+
231+
#[cfg(feature = "lapack")]
232+
#[test]
233+
fn test_pca_with_lapack_svd() {
234+
235+
236+
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
237+
let mut pca = PCABuilder::new(LapackSVD).n_components(2).build();
238+
239+
pca.fit(x.view()).unwrap();
240+
241+
assert!(pca.components().is_some());
242+
}
243+
244+
#[cfg(feature = "faer")]
245+
#[test]
246+
fn test_pca_with_faer_svd() {
247+
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
248+
let mut pca = PCABuilder::new(FaerSVD).n_components(2).build();
249+
250+
pca.fit(x.view()).unwrap();
251+
252+
assert!(pca.components().is_some());
253+
}
254+
255+
#[cfg(feature = "lapack")]
256+
#[test]
257+
fn test_pca_with_different_n_components_lap() {
258+
let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
259+
let mut pca = PCABuilder::new(LapackSVD).n_components(2).build();
260+
261+
// pca.fit(x.view()).unwrap();
262+
// let transformed = pca.transform(x.view()).unwrap();
263+
264+
// assert_eq!(transformed.shape(), &[3, 2]);
265+
266+
// // Test with n_components = 1
267+
let mut pca_1 = PCABuilder::new(LapackSVD).n_components(1).build();
268+
pca_1.fit(x.view()).unwrap();
269+
let transformed_1 = pca_1.transform(x.view()).unwrap();
270+
assert_eq!(transformed_1.shape(), &[3, 1]);
271+
272+
// Test with n_components = 3 (full dimensionality)
273+
let mut pca_3 = PCABuilder::new(LapackSVD).n_components(3).build();
274+
pca_3.fit(x.view()).unwrap();
275+
let transformed_3 = pca_3.transform(x.view()).unwrap();
276+
assert_eq!(transformed_3.shape(), &[3, 3]);
277+
}
278+
279+
#[cfg(feature = "faer")]
280+
#[test]
281+
fn test_pca_with_different_n_components_faer() {
282+
let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
283+
let mut pca = PCABuilder::new(FaerSVD).n_components(2).build();
284+
285+
// pca.fit(x.view()).unwrap();
286+
// let transformed = pca.transform(x.view()).unwrap();
287+
288+
// assert_eq!(transformed.shape(), &[3, 2]);
289+
290+
// // Test with n_components = 1
291+
let mut pca_1 = PCABuilder::new(FaerSVD).n_components(1).build();
292+
pca_1.fit(x.view()).unwrap();
293+
let transformed_1 = pca_1.transform(x.view()).unwrap();
294+
assert_eq!(transformed_1.shape(), &[3, 1]);
295+
296+
// Test with n_components = 3 (full dimensionality)
297+
let mut pca_3 = PCABuilder::new(FaerSVD).n_components(3).build();
298+
pca_3.fit(x.view()).unwrap();
299+
let transformed_3 = pca_3.transform(x.view()).unwrap();
300+
assert_eq!(transformed_3.shape(), &[3, 3]);
301+
}
302+
303+
#[test]
304+
#[should_panic(expected = "PCA has not been fitted yet")]
305+
#[cfg(feature = "faer")]
306+
fn test_pca_transform_without_fit() {
307+
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
308+
let pca = PCABuilder::new(FaerSVD).n_components(2).build();
309+
310+
pca.transform(x.view()).unwrap();
311+
}
312+
}

0 commit comments

Comments
 (0)