@@ -16,17 +16,19 @@ const ungqr! = orgqr!
1616
1717# Wrapper for SVD via QR Iteration
1818for (fname, elty, relty) in
19- ((:rocsolver_sgesvd, :Float32, :Float32),
20- (:rocsolver_dgesvd, :Float64, :Float64),
21- (:rocsolver_cgesvd, :ComplexF32, :Float32),
22- (:rocsolver_zgesvd, :ComplexF64, :Float64))
19+ (
20+ (:rocsolver_sgesvd, :Float32, :Float32),
21+ (:rocsolver_dgesvd, :Float64, :Float64),
22+ (:rocsolver_cgesvd, :ComplexF32, :Float32),
23+ (:rocsolver_zgesvd, :ComplexF64, :Float64),
24+ )
2325 @eval begin
24- # ! format: off
25- function gesvd!( A:: StridedROCMatrix{$elty} ,
26- S:: StridedROCVector{$relty} = similar(A, $ relty, min(size(A). .. )),
27- U:: StridedROCMatrix{$elty} = similar(A, $ elty, size(A, 1 ), min(size(A). .. )),
28- Vᴴ:: StridedROCMatrix{$elty} = similar(A, $ elty, min(size(A). .. ), size(A, 2 ) ))
29- # ! format: on
26+ function gesvd!(
27+ A:: StridedROCMatrix{$elty} ,
28+ S:: StridedROCVector{$relty} = similar(A, $ relty, min(size(A). .. )),
29+ U:: StridedROCMatrix{$elty} = similar(A, $ elty, size(A, 1 ), min(size(A). .. )),
30+ Vᴴ:: StridedROCMatrix{$elty} = similar(A, $ elty, min(size(A). .. ), size(A, 2 ))
31+ )
3032 chkstride1(A, U, Vᴴ, S)
3133 m, n = size(A)
3234 (m < n) && throw(ArgumentError(" rocSOLVER's gesvd requires m ≥ n" ))
@@ -72,13 +74,15 @@ for (fname, elty, relty) in
7274 ldu = max(1 , stride(U, 2 ))
7375 ldv = max(1 , stride(Vᴴ, 2 ))
7476
75- rwork = ROCArray{$ relty}(undef, minmn - 1 )
76- dh = rocBLAS. handle()
77+ rwork = ROCArray{$ relty}(undef, minmn - 1 )
78+ dh = rocBLAS. handle()
7779 dev_info = ROCVector{Cint}(undef, 1 )
78- rocSOLVER.$ fname(dh, jobu, jobvt, m, n,
79- A, lda, S, U, ldu, Vᴴ, ldv,
80- rwork, convert(rocSOLVER. rocblas_workmode, ' I' ),
81- dev_info)
80+ rocSOLVER.$ fname(
81+ dh, jobu, jobvt, m, n,
82+ A, lda, S, U, ldu, Vᴴ, ldv,
83+ rwork, convert(rocSOLVER. rocblas_workmode, ' I' ),
84+ dev_info
85+ )
8286 AMDGPU. unsafe_free!(rwork)
8387
8488 info = @allowscalar dev_info[1 ]
9195
9296# Wrapper for SVD via Jacobi
9397for (fname, elty, relty) in
94- ((:rocsolver_sgesvdj, :Float32, :Float32),
95- (:rocsolver_dgesvdj, :Float64, :Float64),
96- (:rocsolver_cgesvdj, :ComplexF32, :Float32),
97- (:rocsolver_zgesvdj, :ComplexF64, :Float64))
98+ (
99+ (:rocsolver_sgesvdj, :Float32, :Float32),
100+ (:rocsolver_dgesvdj, :Float64, :Float64),
101+ (:rocsolver_cgesvdj, :ComplexF32, :Float32),
102+ (:rocsolver_zgesvdj, :ComplexF64, :Float64),
103+ )
98104 @eval begin
99- # ! format: off
100- function gesvdj!(A:: StridedROCMatrix{$elty} ,
101- S:: StridedROCVector{$relty} = similar(A, $ relty, min(size(A). .. )),
102- U:: StridedROCMatrix{$elty} = similar(A, $ elty, size(A, 1 ), min(size(A). .. )),
103- Vᴴ:: StridedROCMatrix{$elty} = similar(A, $ elty, min(size(A). .. ), size(A, 2 ));
104- tol:: $relty = eps($ relty),
105- max_sweeps:: Int = 100 ,
106- )
107- # ! format: on
105+ function gesvdj!(
106+ A:: StridedROCMatrix{$elty} ,
107+ S:: StridedROCVector{$relty} = similar(A, $ relty, min(size(A). .. )),
108+ U:: StridedROCMatrix{$elty} = similar(A, $ elty, size(A, 1 ), min(size(A). .. )),
109+ Vᴴ:: StridedROCMatrix{$elty} = similar(A, $ elty, min(size(A). .. ), size(A, 2 ));
110+ tol:: $relty = eps($ relty),
111+ max_sweeps:: Int = 100 ,
112+ )
108113 chkstride1(A, U, Vᴴ, S)
109114 m, n = size(A)
110115 minmn = min(m, n)
@@ -149,21 +154,22 @@ for (fname, elty, relty) in
149154 lda = max(1 , stride(A, 2 ))
150155 ldu = max(1 , stride(U, 2 ))
151156 ldv = max(1 , stride(Vᴴ, 2 ))
152- dev_info = ROCVector{Cint}(undef, 1 )
157+ dev_info = ROCVector{Cint}(undef, 1 )
153158 dev_residual = ROCVector{$ relty}(undef, 1 )
154159 dev_n_sweeps = ROCVector{Cint}(undef, 1 )
155160
156161 dh = rocBLAS. handle()
157- rocSOLVER.$ fname(dh, jobu, jobvt, m, n, A, lda, tol,
158- dev_residual, max_sweeps, dev_n_sweeps,
159- S, U, ldu, Vᴴ, ldv, dev_info,
160- )
162+ rocSOLVER.$ fname(
163+ dh, jobu, jobvt, m, n, A, lda, tol,
164+ dev_residual, max_sweeps, dev_n_sweeps,
165+ S, U, ldu, Vᴴ, ldv, dev_info,
166+ )
161167
162168 info = @allowscalar dev_info[1 ]
163169 rocSOLVER. chkargsok(BlasInt(info))
164170
165- AMDGPU. unsafe_free!(dev_residual)
166- AMDGPU. unsafe_free!(dev_n_sweeps)
171+ AMDGPU. unsafe_free!(dev_residual)
172+ AMDGPU. unsafe_free!(dev_n_sweeps)
167173 return (S, U, Vᴴ)
168174 end
169175 end
@@ -476,15 +482,19 @@ end
476482# end
477483
478484for (heevd, heev, heevx, heevj, elty, relty) in
479- ((:(rocSOLVER. rocsolver_ssyevd), :(rocSOLVER. rocsolver_ssyev), :(rocSOLVER. rocsolver_ssyevx), :(rocSOLVER. rocsolver_ssyevj), :Float32, :Float32),
480- (:(rocSOLVER. rocsolver_dsyevd), :(rocSOLVER. rocsolver_dsyev), :(rocSOLVER. rocsolver_dsyevx), :(rocSOLVER. rocsolver_dsyevj), :Float64, :Float64),
481- (:(rocSOLVER. rocsolver_cheevd), :(rocSOLVER. rocsolver_cheev), :(rocSOLVER. rocsolver_cheevx), :(rocSOLVER. rocsolver_cheevj), :ComplexF32, :Float32),
482- (:(rocSOLVER. rocsolver_zheevd), :(rocSOLVER. rocsolver_zheev), :(rocSOLVER. rocsolver_zheevx), :(rocSOLVER. rocsolver_zheevj), :ComplexF64, :Float64))
485+ (
486+ (:(rocSOLVER. rocsolver_ssyevd), :(rocSOLVER. rocsolver_ssyev), :(rocSOLVER. rocsolver_ssyevx), :(rocSOLVER. rocsolver_ssyevj), :Float32, :Float32),
487+ (:(rocSOLVER. rocsolver_dsyevd), :(rocSOLVER. rocsolver_dsyev), :(rocSOLVER. rocsolver_dsyevx), :(rocSOLVER. rocsolver_dsyevj), :Float64, :Float64),
488+ (:(rocSOLVER. rocsolver_cheevd), :(rocSOLVER. rocsolver_cheev), :(rocSOLVER. rocsolver_cheevx), :(rocSOLVER. rocsolver_cheevj), :ComplexF32, :Float32),
489+ (:(rocSOLVER. rocsolver_zheevd), :(rocSOLVER. rocsolver_zheev), :(rocSOLVER. rocsolver_zheevx), :(rocSOLVER. rocsolver_zheevj), :ComplexF64, :Float64),
490+ )
483491 @eval begin
484- function heevd!(A:: StridedROCMatrix{$elty} ,
485- W:: StridedROCVector{$relty} ,
486- V:: StridedROCMatrix{$elty} ;
487- uplo:: Char = ' U' )
492+ function heevd!(
493+ A:: StridedROCMatrix{$elty} ,
494+ W:: StridedROCVector{$relty} ,
495+ V:: StridedROCMatrix{$elty} ;
496+ uplo:: Char = ' U'
497+ )
488498 chkuplo(uplo)
489499 n = checksquare(A)
490500 lda = max(1 , stride(A, 2 ))
@@ -509,10 +519,12 @@ for (heevd, heev, heevx, heevj, elty, relty) in
509519 end
510520 return W, V
511521 end
512- function heev!(A:: StridedROCMatrix{$elty} ,
513- W:: StridedROCVector{$relty} ,
514- V:: StridedROCMatrix{$elty} ;
515- uplo:: Char = ' U' )
522+ function heev!(
523+ A:: StridedROCMatrix{$elty} ,
524+ W:: StridedROCVector{$relty} ,
525+ V:: StridedROCMatrix{$elty} ;
526+ uplo:: Char = ' U'
527+ )
516528 chkuplo(uplo)
517529 n = checksquare(A)
518530 lda = max(1 , stride(A, 2 ))
@@ -537,11 +549,13 @@ for (heevd, heev, heevx, heevj, elty, relty) in
537549 end
538550 return W, V
539551 end
540- function heevx!(A:: StridedROCMatrix{$elty} ,
541- W:: StridedROCVector{$relty} ,
542- V:: StridedROCMatrix{$elty} ;
543- uplo:: Char = ' U' ,
544- kwargs... )
552+ function heevx!(
553+ A:: StridedROCMatrix{$elty} ,
554+ W:: StridedROCVector{$relty} ,
555+ V:: StridedROCMatrix{$elty} ;
556+ uplo:: Char = ' U' ,
557+ kwargs...
558+ )
545559 chkuplo(uplo)
546560 n = checksquare(A)
547561 lda = max(1 , stride(A, 2 ))
@@ -567,27 +581,29 @@ for (heevd, heev, heevx, heevj, elty, relty) in
567581 size(V) == (n, n) || throw(DimensionMismatch(" size mismatch between A and V" ))
568582 jobz = rocSOLVER. rocblas_evect_original
569583 end
570- dh = rocBLAS. handle()
571- abstol = - one($ relty)
572- nev = ROCVector{Cint}(undef, 1 )
573- ldv = max(1 , stride(V, 2 ))
574- ifail = ROCVector{Cint}(undef, n)
584+ dh = rocBLAS. handle()
585+ abstol = - one($ relty)
586+ nev = ROCVector{Cint}(undef, 1 )
587+ ldv = max(1 , stride(V, 2 ))
588+ ifail = ROCVector{Cint}(undef, n)
575589 dev_info = ROCVector{Cint}(undef, 1 )
576590 roc_uplo = convert(rocSOLVER. rocblas_fill, uplo)
577591 $ heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info)
578592
579593 info = @allowscalar dev_info[1 ]
580594 chkargsok(BlasInt(info))
581- m = @allowscalar nev[1 ]
595+ m = @allowscalar nev[1 ]
582596 return W, V, m
583597 end
584- function heevj!(A:: StridedROCMatrix{$elty} ,
585- W:: StridedROCVector{$relty} ,
586- V:: StridedROCMatrix{$elty} ;
587- uplo:: Char = ' U' ,
588- tol:: $relty = eps($ relty),
589- max_sweeps:: Int = 100 ,
590- sort:: Char = ' N' )
598+ function heevj!(
599+ A:: StridedROCMatrix{$elty} ,
600+ W:: StridedROCVector{$relty} ,
601+ V:: StridedROCMatrix{$elty} ;
602+ uplo:: Char = ' U' ,
603+ tol:: $relty = eps($ relty),
604+ max_sweeps:: Int = 100 ,
605+ sort:: Char = ' N'
606+ )
591607 chkuplo(uplo)
592608 n = checksquare(A)
593609 lda = max(1 , stride(A, 2 ))
0 commit comments