Skip to content

Commit a88d4cb

Browse files
authored
Merge pull request #132 from cpmech/improve-pseudo-inverse
Let LAPACK automatically determine lwork for the SVD routine
2 parents 844ea0a + 52743d7 commit a88d4cb

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

russell_lab/src/matrix/mat_pseudo_inverse.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,4 +407,22 @@ mod tests {
407407
let a_ai_a = get_a_times_ai_times_a(&a_copy, &ai);
408408
mat_approx_eq(&a_ai_a, &a_copy, 1e-13);
409409
}
410+
411+
#[test]
412+
fn mat_pseudo_inverse_1x4_works() {
413+
#[rustfmt::skip]
414+
let data = [
415+
[0.25, 0.25, 0.25, 0.25],
416+
];
417+
let mut a = Matrix::from(&data);
418+
let (m, n) = a.dims();
419+
let mut ai = Matrix::new(n, m);
420+
mat_pseudo_inverse(&mut ai, &mut a).unwrap();
421+
let a_copy = Matrix::from(&data);
422+
let a_ai_a = get_a_times_ai_times_a(&a_copy, &ai);
423+
mat_approx_eq(&a_ai_a, &a_copy, 1e-13);
424+
425+
let ai_correct = Matrix::from(&[[1.0], [1.0], [1.0], [1.0]]);
426+
mat_approx_eq(&ai, &ai_correct, 1e-13);
427+
}
410428
}

russell_lab/src/matrix/mat_svd.rs

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,31 @@ pub fn mat_svd(s: &mut Vector, u: &mut Matrix, vt: &mut Matrix, a: &mut Matrix)
164164
let lda = m_i32;
165165
let ldu = m_i32;
166166
let ldvt = n_i32;
167-
const EXTRA: i32 = 1;
168-
let lwork = 5 * to_i32(min_mn) + EXTRA;
169-
let mut work = vec![0.0; lwork as usize];
170167
let mut info = 0;
171168
unsafe {
169+
// first: perform workspace query
170+
let lwork = -1; // to perform workspace query
171+
let mut work = vec![0.0]; // will contain lwork on exit
172+
c_dgesvd(
173+
SVD_CODE_A,
174+
SVD_CODE_A,
175+
&m_i32,
176+
&n_i32,
177+
a.as_mut_data().as_mut_ptr(),
178+
&lda,
179+
s.as_mut_data().as_mut_ptr(),
180+
u.as_mut_data().as_mut_ptr(),
181+
&ldu,
182+
vt.as_mut_data().as_mut_ptr(),
183+
&ldvt,
184+
work.as_mut_ptr(),
185+
&lwork,
186+
&mut info,
187+
);
188+
189+
// second: perform the SVG decomposition
190+
let lwork = work[0] as i32;
191+
let mut work = vec![0.0; lwork as usize];
172192
c_dgesvd(
173193
SVD_CODE_A,
174194
SVD_CODE_A,
@@ -237,7 +257,7 @@ mod tests {
237257
}
238258

239259
#[test]
240-
fn mat_svd_works() {
260+
fn mat_svd_4x3_works() {
241261
// matrix
242262
let s33 = f64::sqrt(3.0) / 3.0;
243263
#[rustfmt::skip]
@@ -282,7 +302,7 @@ mod tests {
282302
}
283303

284304
#[test]
285-
fn mat_svd_1_works() {
305+
fn mat_svd_2x4_works() {
286306
// matrix
287307
#[rustfmt::skip]
288308
let data = [
@@ -322,4 +342,43 @@ mod tests {
322342
}
323343
mat_approx_eq(&usv, &a_copy, 1e-14);
324344
}
345+
346+
#[test]
347+
fn mat_svd_1x4_works() {
348+
// matrix
349+
#[rustfmt::skip]
350+
let data = [
351+
[0.25, 0.25, 0.25, 0.25],
352+
];
353+
let mut a = Matrix::from(&data);
354+
let a_copy = Matrix::from(&data);
355+
356+
// allocate output data
357+
let (m, n) = a.dims();
358+
let min_mn = if m < n { m } else { n };
359+
let mut s = Vector::new(min_mn);
360+
let mut u = Matrix::new(m, m);
361+
let mut vt = Matrix::new(n, n);
362+
363+
// calculate SVD
364+
mat_svd(&mut s, &mut u, &mut vt, &mut a).unwrap();
365+
366+
// check S
367+
#[rustfmt::skip]
368+
let s_correct = &[
369+
0.5,
370+
];
371+
vec_approx_eq(&s, s_correct, 1e-14);
372+
373+
// check SVD
374+
let mut usv = Matrix::new(m, n);
375+
for i in 0..m {
376+
for j in 0..n {
377+
for k in 0..min_mn {
378+
usv.add(i, j, u.get(i, k) * s[k] * vt.get(k, j));
379+
}
380+
}
381+
}
382+
mat_approx_eq(&usv, &a_copy, 1e-14);
383+
}
325384
}

0 commit comments

Comments
 (0)