diff --git a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl index ad1df695..029bc018 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl @@ -164,7 +164,7 @@ for (fname, elty, relty) in AMDGPU.unsafe_free!(dev_residual) AMDGPU.unsafe_free!(dev_n_sweeps) - return U, S, Vᴴ + return (S, U, Vᴴ) end end end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 9fbc2c4d..d51560e4 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: LQViaTransposedQR using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm -import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj! using CUDA using LinearAlgebra using LinearAlgebra: BlasFloat @@ -30,6 +30,7 @@ _gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = YACUSOLVER.ungqr!(A, τ) _gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = YACUSOLVER.unmqr!(side, trans, A, τ, C) _gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = YACUSOLVER.gesvd!(A, S, U, Vᴴ) _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) +_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...) _gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) end diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 03111f3c..d1838536 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -242,11 +242,70 @@ for (bname, fname, elty, relty) in if jobz == 'V' adjoint!(Vᴴ, Ṽ) end - return U, S, Vᴴ + return S, U, Vᴴ end end end +# Wrapper for randomized SVD +function Xgesvdr!(A::StridedCuMatrix{T}, + S::StridedCuVector=similar(A, real(T), min(size(A)...)), + U::StridedCuMatrix{T}=similar(A, T, size(A, 1), min(size(A)...)), + Vᴴ::StridedCuMatrix{T}=similar(A, T, min(size(A)...), size(A, 2)); + k::Int=length(S), + p::Int=min(size(A)...)-k-1, + niters::Int=1) where {T<:BlasFloat} + chkstride1(A, U, S, Vᴴ) + m, n = size(A) + minmn = min(m, n) + jobu = length(U) == 0 ? 'N' : 'S' + jobv = length(Vᴴ) == 0 ? 'N' : 'S' + R = eltype(S) + k < minmn || throw(DimensionMismatch("length of S ($k) must be less than the smaller dimension of A ($minmn)")) + k + p < minmn || throw(DimensionMismatch("length of S ($k) plus oversampling ($p) must be less than the smaller dimension of A ($minmn)")) + R == real(T) || + throw(ArgumentError("S does not have the matching real `eltype` of A")) + + Ṽ = similar(Vᴴ, (n, n)) + Ũ = (size(U) == (m, m)) ? U : similar(U, (m, m)) + lda = max(1, stride(A, 2)) + ldu = max(1, stride(Ũ, 2)) + ldv = max(1, stride(Ṽ, 2)) + params = CUSOLVER.CuSolverParameters() + dh = CUSOLVER.dense_handle() + + function bufferSize() + out_cpu = Ref{Csize_t}(0) + out_gpu = Ref{Csize_t}(0) + CUSOLVER.cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters, + T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv, + T, out_gpu, out_cpu) + + return out_gpu[], out_cpu[] + end + CUSOLVER.with_workspaces(dh.workspace_gpu, dh.workspace_cpu, + bufferSize()...) do buffer_gpu, buffer_cpu + return CUSOLVER.cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters, + T, A, lda, R, S, T, Ũ, ldu, T, Ṽ, ldv, + T, buffer_gpu, sizeof(buffer_gpu), + buffer_cpu, sizeof(buffer_cpu), + dh.info) + end + + flag = @allowscalar dh.info[1] + CUSOLVER.chklapackerror(BlasInt(flag)) + if Ũ !== U && length(U) > 0 + U .= view(Ũ, 1:m, 1:size(U, 2)) + end + if length(Vᴴ) > 0 + Vᴴ .= view(Ṽ', 1:size(Vᴴ, 1), 1:n) + end + Ũ !== U && CUDA.unsafe_free!(Ũ) + CUDA.unsafe_free!(Ṽ) + + return S, U, Vᴴ +end + # for (jname, bname, fname, elty, relty) in # ((:sygvd!, :cusolverDnSsygvd_bufferSize, :cusolverDnSsygvd, :Float32, :Float32), # (:sygvd!, :cusolverDnDsygvd_bufferSize, :cusolverDnDsygvd, :Float64, :Float64), diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 9e3685bb..d074a16e 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -33,7 +33,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LQViaTransposedQR, - CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, + CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered diff --git a/src/common/gauge.jl b/src/common/gauge.jl index 0f64e0a2..a9bf985b 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -1,7 +1,7 @@ function gaugefix!(V::AbstractMatrix) for j in axes(V, 2) v = view(V, :, j) - s = conj(sign(argmax(abs, v))) + s = conj(sign(_argmaxabs(v))) @inbounds v .*= s end return V diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 84e97c4f..51b29145 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -8,7 +8,7 @@ function copy_input(::typeof(eig_vals), A::AbstractMatrix) end copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A) -function check_input(::typeof(eig_full!), A::AbstractMatrix, DV) +function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) D, V = DV @@ -19,7 +19,7 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV) @check_scalar(V, A, complex) return nothing end -function check_input(::typeof(eig_vals!), A::AbstractMatrix, D) +function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) @assert D isa AbstractVector @@ -51,7 +51,7 @@ end # -------------- # actual implementation function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) - check_input(eig_full!, A, DV) + check_input(eig_full!, A, DV, alg) D, V = DV if alg isa LAPACK_Simple isempty(alg.kwargs) || @@ -66,7 +66,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) end function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm) - check_input(eig_vals!, A, D) + check_input(eig_vals!, A, D, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) if alg isa LAPACK_Simple isempty(alg.kwargs) || diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 5178748e..1e6a47f4 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -8,7 +8,7 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix) end copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A) -function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV) +function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) D, V = DV @@ -19,7 +19,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV) @check_scalar(V, A) return nothing end -function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D) +function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm) m, n = size(A) @assert D isa AbstractVector @check_size(D, (n,)) @@ -48,7 +48,7 @@ end # Implementation # -------------- function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) - check_input(eigh_full!, A, DV) + check_input(eigh_full!, A, DV, alg) D, V = DV Dd = D.diag if alg isa LAPACK_MultipleRelativelyRobustRepresentations @@ -70,7 +70,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) end function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) - check_input(eigh_vals!, A, D) + check_input(eigh_vals!, A, D, alg) V = similar(A, (size(A, 1), 0)) if alg isa LAPACK_MultipleRelativelyRobustRepresentations YALAPACK.heevr!(A, D, V; alg.kwargs...) diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 859260fa..4369b8de 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -7,7 +7,7 @@ function copy_input(::typeof(gen_eig_vals), A::AbstractMatrix, B::AbstractMatrix return copy_input(gen_eig_full, A, B) end -function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV) +function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm) ma, na = size(A) mb, nb = size(B) ma == na || throw(DimensionMismatch("square input matrix A expected")) @@ -24,7 +24,7 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr @check_scalar(V, B, complex) return nothing end -function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W) +function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm) ma, na = size(A) mb, nb = size(B) ma == na || throw(DimensionMismatch("square input matrix A expected")) @@ -57,7 +57,7 @@ end # -------------- # actual implementation function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm) - check_input(gen_eig_full!, A, B, WV) + check_input(gen_eig_full!, A, B, WV, alg) W, V = WV if alg isa LAPACK_Simple isempty(alg.kwargs) || @@ -72,7 +72,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig end function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm) - check_input(gen_eig_vals!, A, B, W) + check_input(gen_eig_vals!, A, B, W, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) if alg isa LAPACK_Simple isempty(alg.kwargs) || diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index 6c81ca12..98617fbf 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -10,7 +10,7 @@ function copy_input(::typeof(lq_null), A::AbstractMatrix) return copy!(similar(A, float(eltype(A))), A) end -function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ) +function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm) m, n = size(A) L, Q = LQ @assert L isa AbstractMatrix && Q isa AbstractMatrix @@ -20,7 +20,7 @@ function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ) @check_scalar(Q, A) return nothing end -function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ) +function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) L, Q = LQ @@ -31,7 +31,7 @@ function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ) @check_scalar(Q, A) return nothing end -function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ) +function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) @assert Nᴴ isa AbstractMatrix @@ -66,36 +66,36 @@ end # -------------- # actual implementation function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) - check_input(lq_full!, A, LQ) + check_input(lq_full!, A, LQ, alg) L, Q = LQ _lapack_lq!(A, L, Q; alg.kwargs...) return L, Q end function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) - check_input(lq_full!, A, LQ) + check_input(lq_full!, A, LQ, alg) L, Q = LQ lq_via_qr!(A, L, Q, alg.qr_alg) return L, Q end function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) - check_input(lq_compact!, A, LQ) + check_input(lq_compact!, A, LQ, alg) L, Q = LQ _lapack_lq!(A, L, Q; alg.kwargs...) return L, Q end function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) - check_input(lq_compact!, A, LQ) + check_input(lq_compact!, A, LQ, alg) L, Q = LQ lq_via_qr!(A, L, Q, alg.qr_alg) return L, Q end function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ) - check_input(lq_null!, A, Nᴴ) + check_input(lq_null!, A, Nᴴ, alg) _lapack_lq_null!(A, Nᴴ; alg.kwargs...) return Nᴴ end function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR) - check_input(lq_null!, A, Nᴴ) + check_input(lq_null!, A, Nᴴ, alg) lq_null_via_qr!(A, Nᴴ, alg.qr_alg) return Nᴴ end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index f53055db..c6dc6248 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -5,7 +5,7 @@ copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever nee copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else -function check_input(::typeof(left_orth!), A::AbstractMatrix, VC) +function check_input(::typeof(left_orth!), A::AbstractMatrix, VC, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) V, C = VC @@ -18,7 +18,7 @@ function check_input(::typeof(left_orth!), A::AbstractMatrix, VC) end return nothing end -function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ) +function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) C, Vᴴ = CVᴴ @@ -32,7 +32,7 @@ function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ) return nothing end -function check_input(::typeof(left_null!), A::AbstractMatrix, N) +function check_input(::typeof(left_null!), A::AbstractMatrix, N, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) @assert N isa AbstractMatrix @@ -40,7 +40,7 @@ function check_input(::typeof(left_null!), A::AbstractMatrix, N) @check_scalar(N, A) return nothing end -function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ) +function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) @assert Nᴴ isa AbstractMatrix @@ -84,7 +84,6 @@ end function left_orth!(A, VC; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) - check_input(left_orth!, A, VC) if !isnothing(trunc) && kind != :svd throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end @@ -100,20 +99,24 @@ function left_orth!(A, VC; trunc=nothing, end function left_orth_qr!(A, VC, alg) alg′ = select_algorithm(qr_compact!, A, alg) + check_input(left_orth!, A, VC, alg′) return qr_compact!(A, VC, alg′) end function left_orth_polar!(A, VC, alg) alg′ = select_algorithm(left_polar!, A, alg) + check_input(left_orth!, A, VC, alg′) return left_polar!(A, VC, alg′) end function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(left_orth!, A, VC, alg′) U, S, Vᴴ = svd_compact!(A, alg′) V, C = VC return copy!(V, U), mul!(C, S, Vᴴ) end function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(left_orth!, A, VC, alg′) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg′)) U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′) @@ -121,6 +124,7 @@ function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing) end function left_orth_svd!(A, VC, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(left_orth!, A, VC, alg′) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) U, S, Vᴴ = svd_trunc!(A, alg_trunc) V, C = VC @@ -128,6 +132,7 @@ function left_orth_svd!(A, VC, alg, trunc) end function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(left_orth!, A, VC, alg′) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) @@ -138,7 +143,6 @@ end function right_orth!(A, CVᴴ; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;)) - check_input(right_orth!, A, CVᴴ) if !isnothing(trunc) && kind != :svd throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end @@ -154,20 +158,24 @@ function right_orth!(A, CVᴴ; trunc=nothing, end function right_orth_lq!(A, CVᴴ, alg) alg′ = select_algorithm(lq_compact!, A, alg) + check_input(right_orth!, A, CVᴴ, alg′) return lq_compact!(A, CVᴴ, alg′) end function right_orth_polar!(A, CVᴴ, alg) alg′ = select_algorithm(right_polar!, A, alg) + check_input(right_orth!, A, CVᴴ, alg′) return right_polar!(A, CVᴴ, alg′) end function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(right_orth!, A, CVᴴ, alg′) U, S, Vᴴ′ = svd_compact!(A, alg′) C, Vᴴ = CVᴴ return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) end function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(right_orth!, A, CVᴴ, alg′) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg′)) U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′) @@ -175,6 +183,7 @@ function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing) end function right_orth_svd!(A, CVᴴ, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(right_orth!, A, CVᴴ, alg′) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) U, S, Vᴴ′ = svd_trunc!(A, alg_trunc) C, Vᴴ = CVᴴ @@ -182,6 +191,7 @@ function right_orth_svd!(A, CVᴴ, alg, trunc) end function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) + check_input(right_orth!, A, CVᴴ, alg′) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) @@ -204,7 +214,6 @@ end function left_null!(A, N; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_svd=(;)) - check_input(left_null!, A, N) if !isnothing(trunc) && kind != :svd throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end @@ -218,10 +227,12 @@ function left_null!(A, N; trunc=nothing, end function left_null_qr!(A, N, alg) alg′ = select_algorithm(qr_null!, A, alg) + check_input(left_null!, A, N, alg′) return qr_null!(A, N, alg′) end function left_null_svd!(A, N, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_full!, A, alg) + check_input(left_null!, A, N, alg′) U, _, _ = svd_full!(A, alg′) (m, n) = size(A) return copy!(N, view(U, 1:m, (n + 1):m)) @@ -238,7 +249,6 @@ end function right_null!(A, Nᴴ; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), alg_svd=(;)) - check_input(right_null!, A, Nᴴ) if !isnothing(trunc) && kind != :svd throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end @@ -252,16 +262,19 @@ function right_null!(A, Nᴴ; trunc=nothing, end function right_null_lq!(A, Nᴴ, alg) alg′ = select_algorithm(lq_null!, A, alg) + check_input(right_null!, A, Nᴴ, alg′) return lq_null!(A, Nᴴ, alg′) end function right_null_svd!(A, Nᴴ, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_full!, A, alg) + check_input(right_null!, A, Nᴴ, alg′) _, _, Vᴴ = svd_full!(A, alg′) (m, n) = size(A) return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) end function right_null_svd!(A, Nᴴ, alg, trunc) alg′ = select_algorithm(svd_full!, A, alg) + check_input(right_null!, A, Nᴴ, alg′) _, S, Vᴴ = svd_full!(A, alg′) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : diff --git a/src/implementations/polar.jl b/src/implementations/polar.jl index 0e0cb553..1128ccda 100644 --- a/src/implementations/polar.jl +++ b/src/implementations/polar.jl @@ -3,7 +3,7 @@ copy_input(::typeof(left_polar), A) = copy_input(svd_full, A) copy_input(::typeof(right_polar), A) = copy_input(svd_full, A) -function check_input(::typeof(left_polar!), A::AbstractMatrix, WP) +function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlgorithm) m, n = size(A) W, P = WP m >= n || @@ -15,7 +15,7 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP) @check_scalar(P, A) return nothing end -function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ) +function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::AbstractAlgorithm) m, n = size(A) P, Wᴴ = PWᴴ n >= m || @@ -46,7 +46,7 @@ end # Implementation # -------------- function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD) - check_input(left_polar!, A, WP) + check_input(left_polar!, A, WP, alg) U, S, Vᴴ = svd_compact!(A, alg.svdalg) W, P = WP W = mul!(W, U, Vᴴ) @@ -56,7 +56,7 @@ function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD) return (W, P) end function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD) - check_input(right_polar!, A, PWᴴ) + check_input(right_polar!, A, PWᴴ, alg) U, S, Vᴴ = svd_compact!(A, alg.svdalg) P, Wᴴ = PWᴴ Wᴴ = mul!(Wᴴ, U, Vᴴ) diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 5b45eff9..638d65d8 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -10,7 +10,7 @@ function copy_input(::typeof(qr_null), A::AbstractMatrix) return copy!(similar(A, float(eltype(A))), A) end -function check_input(::typeof(qr_full!), A::AbstractMatrix, QR) +function check_input(::typeof(qr_full!), A::AbstractMatrix, QR, ::AbstractAlgorithm) m, n = size(A) Q, R = QR @assert Q isa AbstractMatrix && R isa AbstractMatrix @@ -20,7 +20,7 @@ function check_input(::typeof(qr_full!), A::AbstractMatrix, QR) @check_scalar(R, A) return nothing end -function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR) +function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) Q, R = QR @@ -31,7 +31,7 @@ function check_input(::typeof(qr_compact!), A::AbstractMatrix, QR) @check_scalar(R, A) return nothing end -function check_input(::typeof(qr_null!), A::AbstractMatrix, N) +function check_input(::typeof(qr_null!), A::AbstractMatrix, N, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) @assert N isa AbstractMatrix @@ -66,19 +66,19 @@ end # -------------- # actual implementation function qr_full!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR) - check_input(qr_full!, A, QR) + check_input(qr_full!, A, QR, alg) Q, R = QR _lapack_qr!(A, Q, R; alg.kwargs...) return Q, R end function qr_compact!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR) - check_input(qr_compact!, A, QR) + check_input(qr_compact!, A, QR, alg) Q, R = QR _lapack_qr!(A, Q, R; alg.kwargs...) return Q, R end function qr_null!(A::AbstractMatrix, N, alg::LAPACK_HouseholderQR) - check_input(qr_null!, A, N) + check_input(qr_null!, A, N, alg) _lapack_qr_null!(A, N; alg.kwargs...) return N end @@ -172,19 +172,19 @@ end # CUDA and AMDGPU ### function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) - check_input(qr_full!, A, QR) + check_input(qr_full!, A, QR, alg) Q, R = QR _gpu_qr!(A, Q, R; alg.kwargs...) return Q, R end function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) - check_input(qr_compact!, A, QR) + check_input(qr_compact!, A, QR, alg) Q, R = QR _gpu_qr!(A, Q, R; alg.kwargs...) return Q, R end function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) - check_input(qr_null!, A, N) + check_input(qr_null!, A, N, alg) _gpu_qr_null!(A, N; alg.kwargs...) return N end diff --git a/src/implementations/schur.jl b/src/implementations/schur.jl index 541ae97f..1cc3ba5f 100644 --- a/src/implementations/schur.jl +++ b/src/implementations/schur.jl @@ -4,7 +4,7 @@ copy_input(::typeof(schur_full), A::AbstractMatrix) = copy_input(eig_full, A) copy_input(::typeof(schur_vals), A::AbstractMatrix) = copy_input(eig_vals, A) # check input -function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv) +function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) T, Z, vals = TZv @@ -17,7 +17,7 @@ function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv) @check_scalar(vals, A, complex) return nothing end -function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals) +function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected")) @assert vals isa AbstractVector @@ -43,7 +43,7 @@ end # Implementation # -------------- function schur_full!(A::AbstractMatrix, TZv, alg::LAPACK_EigAlgorithm) - check_input(schur_full!, A, TZv) + check_input(schur_full!, A, TZv, alg) T, Z, vals = TZv if alg isa LAPACK_Simple isempty(alg.kwargs) || @@ -59,7 +59,7 @@ function schur_full!(A::AbstractMatrix, TZv, alg::LAPACK_EigAlgorithm) end function schur_vals!(A::AbstractMatrix, vals, alg::LAPACK_EigAlgorithm) - check_input(schur_vals!, A, vals) + check_input(schur_vals!, A, vals, alg) Z = similar(A, eltype(A), (size(A, 1), 0)) if alg isa LAPACK_Simple isempty(alg.kwargs) || diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9de3777d..83e5f220 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -8,7 +8,7 @@ copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A) copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A) # TODO: many of these checks are happening again in the LAPACK routines -function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ) +function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ, ::AbstractAlgorithm) m, n = size(A) U, S, Vᴴ = USVᴴ @assert U isa AbstractMatrix && S isa AbstractMatrix && Vᴴ isa AbstractMatrix @@ -20,7 +20,7 @@ function check_input(::typeof(svd_full!), A::AbstractMatrix, USVᴴ) @check_scalar(Vᴴ, A) return nothing end -function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ) +function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) U, S, Vᴴ = USVᴴ @@ -33,7 +33,7 @@ function check_input(::typeof(svd_compact!), A::AbstractMatrix, USVᴴ) @check_scalar(Vᴴ, A) return nothing end -function check_input(::typeof(svd_vals!), A::AbstractMatrix, S) +function check_input(::typeof(svd_vals!), A::AbstractMatrix, S, ::AbstractAlgorithm) m, n = size(A) minmn = min(m, n) @assert S isa AbstractVector @@ -66,10 +66,56 @@ function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::Truncat return initialize_output(svd_compact!, A, alg.alg) end +function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int) + for j in 1:max(m, n) + if j <= min(m, n) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + elseif j <= m + u = view(U, :, j) + s = conj(sign(_argmaxabs(u))) + u .*= s + else + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(v))) + v .*= s + end + end + return (U, S, Vᴴ) +end + +# Gauge fixation +# -------------- +function gaugefix!(::typeof(svd_compact!), U, S, Vᴴ, m::Int, n::Int) + for j in 1:size(U, 2) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + end + return (U, S, Vᴴ) +end + +function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int) + for j in 1:min(m, n) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + end + return (U, S, Vᴴ) +end + + # Implementation # -------------- function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ) + check_input(svd_full!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ fill!(S, zero(eltype(S))) m, n = size(A) @@ -100,28 +146,12 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) S[i, 1] = zero(eltype(S)) end # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:max(m, n) - if j <= minmn - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(argmax(abs, u))) - u .*= s - v .*= conj(s) - elseif j <= m - u = view(U, :, j) - s = conj(sign(argmax(abs, u))) - u .*= s - else - v = view(Vᴴ, j, :) - s = conj(sign(argmax(abs, v))) - v .*= s - end - end + gaugefix!(svd_full!, U, S, Vᴴ, m, n) return USVᴴ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ) + check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ if alg isa LAPACK_QRIteration isempty(alg.kwargs) || @@ -141,18 +171,12 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:size(U, 2) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(argmax(abs, u))) - u .*= s - v .*= conj(s) - end + gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) return USVᴴ end function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) - check_input(svd_vals!, A, S) + check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) if alg isa LAPACK_QRIteration isempty(alg.kwargs) || @@ -185,7 +209,8 @@ end ### const CUSOLVER_SVDAlgorithm = Union{CUSOLVER_QRIteration, CUSOLVER_SVDPolar, - CUSOLVER_Jacobi} + CUSOLVER_Jacobi, + CUSOLVER_Randomized} const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} @@ -193,14 +218,38 @@ const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} +const GPU_Randomized = Union{CUSOLVER_Randomized} + +function check_input(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized) + m, n = size(A) + minmn = min(m, n) + U, S, Vᴴ = USVᴴ + @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix + @check_size(U, (m, m)) + @check_scalar(U, A) + @check_size(S, (minmn, minmn)) + @check_scalar(S, A, real) + @check_size(Vᴴ, (n, n)) + @check_scalar(Vᴴ, A) + return nothing +end + +function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized}) + m, n = size(A) + minmn = min(m, n) + U = similar(A, (m, m)) + S = Diagonal(similar(A, real(eltype(A)), (minmn,))) + Vᴴ = similar(A, (n, n)) + return (U, S, Vᴴ) +end _gpu_gesvd!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) = throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ))) _gpu_Xgesvdp!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ))) +_gpu_Xgesvdr!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ))) _gpu_gesvdj!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) - # GPU SVD implementation function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ) + check_input(svd_full!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ fill!(S, zero(eltype(S))) m, n = size(A) @@ -223,28 +272,21 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgor diagview(S) .= view(S, 1:minmn, 1) view(S, 2:minmn, 1) .= zero(eltype(S)) # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:max(m, n) - if j <= minmn - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - elseif j <= m - u = view(U, :, j) - s = conj(sign(_argmaxabs(u))) - u .*= s - else - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(v))) - v .*= s - end - end + gaugefix!(svd_full!, U, S, Vᴴ, m, n) return USVᴴ end +function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) + check_input(svd_trunc!, A, USVᴴ, alg.alg) + U, S, Vᴴ = USVᴴ + _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) + # TODO: make this controllable using a `gaugefix` keyword argument + gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) + return truncate!(svd_trunc!, USVᴴ, alg.trunc) +end + function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ) + check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ if alg isa GPU_QRIteration isempty(alg.kwargs) || @@ -258,20 +300,14 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl throw(ArgumentError("Unsupported SVD algorithm")) end # TODO: make this controllable using a `gaugefix` keyword argument - for j in 1:size(U, 2) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - end + gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) return USVᴴ end _argmaxabs(x) = reduce(_largest, x; init=zero(eltype(x))) _largest(x, y) = abs(x) < abs(y) ? y : x function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) - check_input(svd_vals!, A, S) + check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) if alg isa GPU_QRIteration isempty(alg.kwargs) || diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 9485c745..e11c6bfa 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -151,6 +151,19 @@ a general matrix using the Jacobi algorithm. """ @algdef CUSOLVER_Jacobi +""" + CUSOLVER_Randomized(; p, niters) + +Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of +a general matrix using the randomized SVD algorithm. + +!!! note + Randomized SVD cannot compute all singular values of the input matrix `A`, only the first `k` where + `k < min(m, n)`. The remainder are used for oversampling. See the [CUSOLVER documentation](https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgesvdr) + for more information. +""" +@algdef CUSOLVER_Randomized + # ========================= # ROCSOLVER ALGORITHMS # ========================= diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index 913f77db..8765ff08 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -1,6 +1,6 @@ using MatrixAlgebraKit using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef +using LinearAlgebra: Diagonal, isposdef, opnorm using Test using TestExtras using StableRNGs @@ -76,45 +76,47 @@ end @test isapproxone(Vᴴ' * Vᴴ) @test all(isposdef, diagview(S)) - Sc = similar(A, real(T), min(m, n)) + minmn = min(m, n) + Sc = similar(A, real(T), minmn) Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) @test Sc === Sc2 @test CuArray(diagview(S)) ≈ Sc # CuArray is necessary because norm of CuArray view with non-unit step is broken end + @testset "algorithm $alg" for alg in algs + end end end -# @testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) -# rng = StableRNG(123) -# m = 54 -# if LinearAlgebra.LAPACK.version() < v"3.12.0" -# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) -# else -# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), -# LAPACK_Jacobi()) -# end - -# @testset "size ($m, $n)" for n in (37, m, 63) -# @testset "algorithm $alg" for alg in algs -# n > m && alg isa LAPACK_Jacobi && continue # not supported -# A = randn(rng, T, m, n) -# S₀ = svd_vals(A) -# minmn = min(m, n) -# r = minmn - 2 +@testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63) + k = min(m, n) - 20 + p = min(m, n) - k - 1 + algs = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k=k, p=p, niters=100),) + @testset "algorithm $alg" for alg in algs + n > m && alg isa CUSOLVER_QRIteration && continue # not supported + hA = randn(rng, T, m, n) + S₀ = svd_vals(hA) + A = CuArray(hA) + minmn = min(m, n) + r = k -# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) -# @test length(S1.diag) == r -# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) + @test length(S1.diag) == r + @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] -# s = 1 + sqrt(eps(real(T))) -# trunc2 = trunctol(s * S₀[r + 1]) + if !(alg isa CUSOLVER_Randomized) + s = 1 + sqrt(eps(real(T))) + trunc2 = trunctol(s * S₀[r + 1]) -# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) -# @test length(S2.diag) == r -# @test U1 ≈ U2 -# @test S1 ≈ S2 -# @test V1ᴴ ≈ V2ᴴ -# end -# end -# end + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(s * S₀[r + 1])) + @test length(S2.diag) == r + @test U1 ≈ U2 + @test parent(S1) ≈ parent(S2) + @test V1ᴴ ≈ V2ᴴ + end + end + end +end diff --git a/test/orthnull.jl b/test/orthnull.jl index cc8f2a19..c402e830 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -5,7 +5,7 @@ using StableRNGs using LinearAlgebra: LinearAlgebra, I, mul! using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, - initialize_output + initialize_output, AbstractAlgorithm # Used to test non-AbstractMatrix codepaths. struct LinearMap{P<:AbstractMatrix} @@ -33,11 +33,11 @@ end function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap) return LinearMap.(initialize_output(right_orth!, parent(A))) end -function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC) - return check_input(left_orth!, parent(A), parent.(VC)) +function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC, alg::AbstractAlgorithm) + return check_input(left_orth!, parent(A), parent.(VC), alg) end -function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) - return check_input(right_orth!, parent(A), parent.(VC)) +function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC, alg::AbstractAlgorithm) + return check_input(right_orth!, parent(A), parent.(VC), alg) end function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A} return default_svd_algorithm(A; kwargs...)