Skip to content
42 changes: 35 additions & 7 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,36 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
return S
end

function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
# nothing case here to handle GenericLinearAlgebra
function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
if !isempty(ϵ)
ϵ .= truncation_error!(diagview(S), ind)
end
return USVᴴtrunc..., ϵ
end
function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) where {Tϵ}
USVᴴ, ϵ = USVᴴϵ
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
if !isempty(ϵ)
ϵ .= truncation_error!(diagview(S), ind)
end
return USVᴴtrunc..., ϵ
end

function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
ϵ = similar(USVᴴ[2], compute_error)
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(eltype(ϵ)))
end
function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true)
Tr = real(eltype(A))
ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0)
U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg)
return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr))
end

# Diagonal logic
Expand Down Expand Up @@ -362,16 +388,18 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
return USVᴴ
end

function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
check_input(svd_trunc!, A, USVᴴ, alg.alg)
U, S, Vᴴ = USVᴴ
function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)

# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)

# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
if !isempty(ϵ)
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
end

do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
Expand Down
34 changes: 17 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ if !is_buildkite
JET.test_package(MatrixAlgebraKit; target_defined_modules = true)
end
end

using GenericLinearAlgebra
@safetestset "QR / LQ Decomposition" begin
include("genericlinearalgebra/qr.jl")
include("genericlinearalgebra/lq.jl")
end
@safetestset "Singular Value Decomposition" begin
include("genericlinearalgebra/svd.jl")
end
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("genericlinearalgebra/eigh.jl")
end

using GenericSchur
@safetestset "General Eigenvalue Decomposition" begin
include("genericschur/eig.jl")
end
end

using CUDA
Expand Down Expand Up @@ -110,20 +127,3 @@ if AMDGPU.functional()
include("amd/orthnull.jl")
end
end

using GenericLinearAlgebra
@safetestset "QR / LQ Decomposition" begin
include("genericlinearalgebra/qr.jl")
include("genericlinearalgebra/lq.jl")
end
@safetestset "Singular Value Decomposition" begin
include("genericlinearalgebra/svd.jl")
end
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("genericlinearalgebra/eigh.jl")
end

using GenericSchur
@safetestset "General Eigenvalue Decomposition" begin
include("genericschur/eig.jl")
end
Loading