@@ -164,11 +164,31 @@ pub fn mat_svd(s: &mut Vector, u: &mut Matrix, vt: &mut Matrix, a: &mut Matrix)
164164 let lda = m_i32;
165165 let ldu = m_i32;
166166 let ldvt = n_i32;
167- const EXTRA : i32 = 1 ;
168- let lwork = 5 * to_i32 ( min_mn) + EXTRA ;
169- let mut work = vec ! [ 0.0 ; lwork as usize ] ;
170167 let mut info = 0 ;
171168 unsafe {
169+ // first: perform workspace query
170+ let lwork = -1 ; // to perform workspace query
171+ let mut work = vec ! [ 0.0 ] ; // will contain lwork on exit
172+ c_dgesvd (
173+ SVD_CODE_A ,
174+ SVD_CODE_A ,
175+ & m_i32,
176+ & n_i32,
177+ a. as_mut_data ( ) . as_mut_ptr ( ) ,
178+ & lda,
179+ s. as_mut_data ( ) . as_mut_ptr ( ) ,
180+ u. as_mut_data ( ) . as_mut_ptr ( ) ,
181+ & ldu,
182+ vt. as_mut_data ( ) . as_mut_ptr ( ) ,
183+ & ldvt,
184+ work. as_mut_ptr ( ) ,
185+ & lwork,
186+ & mut info,
187+ ) ;
188+
189+ // second: perform the SVG decomposition
190+ let lwork = work[ 0 ] as i32 ;
191+ let mut work = vec ! [ 0.0 ; lwork as usize ] ;
172192 c_dgesvd (
173193 SVD_CODE_A ,
174194 SVD_CODE_A ,
@@ -237,7 +257,7 @@ mod tests {
237257 }
238258
239259 #[ test]
240- fn mat_svd_works ( ) {
260+ fn mat_svd_4x3_works ( ) {
241261 // matrix
242262 let s33 = f64:: sqrt ( 3.0 ) / 3.0 ;
243263 #[ rustfmt:: skip]
@@ -282,7 +302,7 @@ mod tests {
282302 }
283303
284304 #[ test]
285- fn mat_svd_1_works ( ) {
305+ fn mat_svd_2x4_works ( ) {
286306 // matrix
287307 #[ rustfmt:: skip]
288308 let data = [
@@ -322,4 +342,43 @@ mod tests {
322342 }
323343 mat_approx_eq ( & usv, & a_copy, 1e-14 ) ;
324344 }
345+
346+ #[ test]
347+ fn mat_svd_1x4_works ( ) {
348+ // matrix
349+ #[ rustfmt:: skip]
350+ let data = [
351+ [ 0.25 , 0.25 , 0.25 , 0.25 ] ,
352+ ] ;
353+ let mut a = Matrix :: from ( & data) ;
354+ let a_copy = Matrix :: from ( & data) ;
355+
356+ // allocate output data
357+ let ( m, n) = a. dims ( ) ;
358+ let min_mn = if m < n { m } else { n } ;
359+ let mut s = Vector :: new ( min_mn) ;
360+ let mut u = Matrix :: new ( m, m) ;
361+ let mut vt = Matrix :: new ( n, n) ;
362+
363+ // calculate SVD
364+ mat_svd ( & mut s, & mut u, & mut vt, & mut a) . unwrap ( ) ;
365+
366+ // check S
367+ #[ rustfmt:: skip]
368+ let s_correct = & [
369+ 0.5 ,
370+ ] ;
371+ vec_approx_eq ( & s, s_correct, 1e-14 ) ;
372+
373+ // check SVD
374+ let mut usv = Matrix :: new ( m, n) ;
375+ for i in 0 ..m {
376+ for j in 0 ..n {
377+ for k in 0 ..min_mn {
378+ usv. add ( i, j, u. get ( i, k) * s[ k] * vt. get ( k, j) ) ;
379+ }
380+ }
381+ }
382+ mat_approx_eq ( & usv, & a_copy, 1e-14 ) ;
383+ }
325384}
0 commit comments