Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions russell_lab/src/matrix/mat_pseudo_inverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,4 +407,22 @@ mod tests {
let a_ai_a = get_a_times_ai_times_a(&a_copy, &ai);
mat_approx_eq(&a_ai_a, &a_copy, 1e-13);
}

#[test]
fn mat_pseudo_inverse_1x4_works() {
#[rustfmt::skip]
let data = [
[0.25, 0.25, 0.25, 0.25],
];
let mut a = Matrix::from(&data);
let (m, n) = a.dims();
let mut ai = Matrix::new(n, m);
mat_pseudo_inverse(&mut ai, &mut a).unwrap();
let a_copy = Matrix::from(&data);
let a_ai_a = get_a_times_ai_times_a(&a_copy, &ai);
mat_approx_eq(&a_ai_a, &a_copy, 1e-13);

let ai_correct = Matrix::from(&[[1.0], [1.0], [1.0], [1.0]]);
mat_approx_eq(&ai, &ai_correct, 1e-13);
}
}
69 changes: 64 additions & 5 deletions russell_lab/src/matrix/mat_svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,31 @@ pub fn mat_svd(s: &mut Vector, u: &mut Matrix, vt: &mut Matrix, a: &mut Matrix)
let lda = m_i32;
let ldu = m_i32;
let ldvt = n_i32;
const EXTRA: i32 = 1;
let lwork = 5 * to_i32(min_mn) + EXTRA;
let mut work = vec![0.0; lwork as usize];
let mut info = 0;
unsafe {
// first: perform workspace query
let lwork = -1; // to perform workspace query
let mut work = vec![0.0]; // will contain lwork on exit
c_dgesvd(
SVD_CODE_A,
SVD_CODE_A,
&m_i32,
&n_i32,
a.as_mut_data().as_mut_ptr(),
&lda,
s.as_mut_data().as_mut_ptr(),
u.as_mut_data().as_mut_ptr(),
&ldu,
vt.as_mut_data().as_mut_ptr(),
&ldvt,
work.as_mut_ptr(),
&lwork,
&mut info,
);

// second: perform the SVG decomposition
let lwork = work[0] as i32;
let mut work = vec![0.0; lwork as usize];
c_dgesvd(
SVD_CODE_A,
SVD_CODE_A,
Expand Down Expand Up @@ -237,7 +257,7 @@ mod tests {
}

#[test]
fn mat_svd_works() {
fn mat_svd_4x3_works() {
// matrix
let s33 = f64::sqrt(3.0) / 3.0;
#[rustfmt::skip]
Expand Down Expand Up @@ -282,7 +302,7 @@ mod tests {
}

#[test]
fn mat_svd_1_works() {
fn mat_svd_2x4_works() {
// matrix
#[rustfmt::skip]
let data = [
Expand Down Expand Up @@ -322,4 +342,43 @@ mod tests {
}
mat_approx_eq(&usv, &a_copy, 1e-14);
}

#[test]
fn mat_svd_1x4_works() {
// matrix
#[rustfmt::skip]
let data = [
[0.25, 0.25, 0.25, 0.25],
];
let mut a = Matrix::from(&data);
let a_copy = Matrix::from(&data);

// allocate output data
let (m, n) = a.dims();
let min_mn = if m < n { m } else { n };
let mut s = Vector::new(min_mn);
let mut u = Matrix::new(m, m);
let mut vt = Matrix::new(n, n);

// calculate SVD
mat_svd(&mut s, &mut u, &mut vt, &mut a).unwrap();

// check S
#[rustfmt::skip]
let s_correct = &[
0.5,
];
vec_approx_eq(&s, s_correct, 1e-14);

// check SVD
let mut usv = Matrix::new(m, n);
for i in 0..m {
for j in 0..n {
for k in 0..min_mn {
usv.add(i, j, u.get(i, k) * s[k] * vt.get(k, j));
}
}
}
mat_approx_eq(&usv, &a_copy, 1e-14);
}
}