Skip to content

Commit cf57841

Browse files
kshyattKatharine HyattJutho
authored
Add optional epsilon vector and keyword arg for error (#113)
* Add optional epsilon vector and keyword arg for error * Make NaN the same type * GLA tests should not be run on GPU * Don't use NaN because of BigFloat * Fix type signature * Update src/implementations/svd.jl Co-authored-by: Jutho <[email protected]> * Update src/implementations/svd.jl Co-authored-by: Jutho <[email protected]> * Fix undefined err * Use broadcasting to set epsilon * Force memory transfer * Comments * New norm calculation * Move to tuples of Nothing * Comment --------- Co-authored-by: Katharine Hyatt <[email protected]> Co-authored-by: Jutho <[email protected]>
1 parent a8a56cf commit cf57841

File tree

4 files changed

+48
-33
lines changed

4 files changed

+48
-33
lines changed

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
99
return GLA_QRIteration()
1010
end
1111

12-
for f! in (:svd_compact!, :svd_full!, :svd_vals!)
13-
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
12+
for f! in (:svd_compact!, :svd_full!)
13+
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing)
1414
end
15+
MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
1516

1617
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
1718
F = svd!(A)
@@ -43,9 +44,8 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T
4344
return GLA_QRIteration(; kwargs...)
4445
end
4546

46-
for f! in (:eigh_full!, :eigh_vals!)
47-
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
48-
end
47+
MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing)
48+
MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
4949

5050
function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
5151
eigval, eigvec = eigen!(Hermitian(A); sortby = real)

ext/MatrixAlgebraKitGenericSchurExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <
99
return GS_QRIteration(; kwargs...)
1010
end
1111

12-
for f! in (:eig_full!, :eig_vals!)
13-
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing
14-
end
12+
MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing)
13+
MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing
1514

1615
function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
1716
D, V = GenericSchur.eigen!(A)

src/implementations/svd.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,20 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
206206
return S
207207
end
208208

209-
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
210-
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
209+
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
210+
ϵ = similar(A, real(eltype(A)), compute_error)
211+
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
212+
return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ)))
213+
end
214+
215+
function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
216+
U, S, Vᴴ, ϵ = USVᴴϵ
217+
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
211218
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
212-
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
219+
if !isempty(ϵ)
220+
ϵ .= truncation_error!(diagview(S), ind)
221+
end
222+
return USVᴴtrunc..., ϵ
213223
end
214224

215225
# Diagonal logic
@@ -362,16 +372,22 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
362372
return USVᴴ
363373
end
364374

365-
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
366-
check_input(svd_trunc!, A, USVᴴ, alg.alg)
367-
U, S, Vᴴ = USVᴴ
375+
function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
376+
U, S, Vᴴ, ϵ = USVᴴϵ
377+
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
368378
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
369379

370380
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
371381
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
372382

373-
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
374-
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
383+
if !isempty(ϵ)
384+
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
385+
normS = norm(diagview(Str))
386+
normA = norm(A)
387+
# equivalent to sqrt(normA^2 - normS^2)
388+
# but may be more accurate
389+
ϵ = sqrt((normA + normS) * (normA - normS))
390+
end
375391

376392
do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
377393
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)

test/runtests.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,23 @@ if !is_buildkite
5656
JET.test_package(MatrixAlgebraKit; target_defined_modules = true)
5757
end
5858
end
59+
60+
using GenericLinearAlgebra
61+
@safetestset "QR / LQ Decomposition" begin
62+
include("genericlinearalgebra/qr.jl")
63+
include("genericlinearalgebra/lq.jl")
64+
end
65+
@safetestset "Singular Value Decomposition" begin
66+
include("genericlinearalgebra/svd.jl")
67+
end
68+
@safetestset "Hermitian Eigenvalue Decomposition" begin
69+
include("genericlinearalgebra/eigh.jl")
70+
end
71+
72+
using GenericSchur
73+
@safetestset "General Eigenvalue Decomposition" begin
74+
include("genericschur/eig.jl")
75+
end
5976
end
6077

6178
using CUDA
@@ -110,20 +127,3 @@ if AMDGPU.functional()
110127
include("amd/orthnull.jl")
111128
end
112129
end
113-
114-
using GenericLinearAlgebra
115-
@safetestset "QR / LQ Decomposition" begin
116-
include("genericlinearalgebra/qr.jl")
117-
include("genericlinearalgebra/lq.jl")
118-
end
119-
@safetestset "Singular Value Decomposition" begin
120-
include("genericlinearalgebra/svd.jl")
121-
end
122-
@safetestset "Hermitian Eigenvalue Decomposition" begin
123-
include("genericlinearalgebra/eigh.jl")
124-
end
125-
126-
using GenericSchur
127-
@safetestset "General Eigenvalue Decomposition" begin
128-
include("genericschur/eig.jl")
129-
end

0 commit comments

Comments
 (0)