Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ passed as the third positional argument in the form of a `NamedTuple`.
""" select_algorithm

function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
return select_algorithm(f, typeof(A), alg; kwargs...)
end
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
Expand Down Expand Up @@ -193,10 +190,24 @@ macro functiondef(f)
end

# define fallbacks for algorithm selection
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
kwargs...) where {Alg,A}
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
return select_algorithm($f!, A, alg; kwargs...)
end
# define default algorithm fallbacks for out-of-place functions
# in terms of the corresponding in-place function
@inline function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
# define default algorithm fallbacks for out-of-place functions
# in terms of the corresponding in-place function for types,
# in principle this is covered by the definition above but
# it is necessary to avoid ambiguity errors with the generic definitions:
# ```julia
# default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
# function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
# throw(MethodError(default_algorithm, (f, T)))
# end
# ```
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_algorithm($f!, A; kwargs...)
end
Expand Down
13 changes: 9 additions & 4 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,13 @@ for f in (:eig_full!, :eig_vals!)
end
end

function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end
end
13 changes: 9 additions & 4 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,13 @@ for f in (:eigh_full!, :eigh_vals!)
end
end

function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end
end
6 changes: 2 additions & 4 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
end

for f in (:lq_full!, :lq_compact!, :lq_null!)
@eval begin
function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_lq_algorithm(A; kwargs...)
end
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_lq_algorithm(A; kwargs...)
end
end
5 changes: 1 addition & 4 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ end
# Algorithm selection
# -------------------
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
function default_polar_algorithm(T::Type; kwargs...)
throw(MethodError(default_polar_algorithm, (T,)))
end
function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
function default_polar_algorithm(::Type{T}; kwargs...) where {T}
return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...))
end

Expand Down
7 changes: 2 additions & 5 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
end

for f in (:qr_full!, :qr_compact!, :qr_null!)
@eval begin
function default_algorithm(::typeof($f), ::Type{A};
kwargs...) where {A<:YALAPACK.BlasMat}
return default_qr_algorithm(A; kwargs...)
end
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_qr_algorithm(A; kwargs...)
end
end
13 changes: 9 additions & 4 deletions src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,13 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
end
end

function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing,
kwargs...) where {A<:YALAPACK.BlasMat}
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
end
end
8 changes: 7 additions & 1 deletion test/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm,
default_algorithm, select_algorithm
TruncationKeepBelow, default_algorithm, select_algorithm

@testset "default_algorithm" begin
A = randn(3, 3)
Expand Down Expand Up @@ -50,6 +50,12 @@ end
NoTruncation())
end

alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0))
for f in (eig_trunc!, eigh_trunc!, svd_trunc!)
@test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2))
end

@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) ===
LAPACK_DivideAndConquer()
Expand Down
16 changes: 15 additions & 1 deletion test/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using TestExtras
using StableRNGs
using LinearAlgebra: Diagonal
using MatrixAlgebraKit: diagview
using MatrixAlgebraKit: TruncatedAlgorithm, diagview

@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
Expand Down Expand Up @@ -57,3 +57,17 @@ end
@test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
end
end

@testset "eig_trunc! specify truncation algorithm T = $T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
m = 4
V = randn(rng, T, m, m)
D = Diagonal([0.9, 0.3, 0.1, 0.01])
A = V * D * inv(V)
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
D2, V2 = @constinferred eig_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
@test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2))
end
18 changes: 17 additions & 1 deletion test/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I
using MatrixAlgebraKit: diagview
using MatrixAlgebraKit: TruncatedAlgorithm, diagview

@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
Expand Down Expand Up @@ -62,3 +62,19 @@ end
@test V2 * (V2' * V1) ≈ V1
end
end

@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
(Float32, Float64,
ComplexF32,
ComplexF64)
rng = StableRNG(123)
m = 4
V = qr_compact(randn(rng, T, m, m))[1]
D = Diagonal([0.9, 0.3, 0.1, 0.01])
A = V * D * V'
A = (A + A') / 2
alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2))
D2, V2 = @constinferred eigh_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
end
17 changes: 16 additions & 1 deletion test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef
using MatrixAlgebraKit: TruncationKeepAbove, diagview
using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview

@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
Expand Down Expand Up @@ -152,3 +152,18 @@ end
end
end
end

@testset "svd_trunc! specify truncation algorithm T = $T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
rng = StableRNG(123)
m = 4
U = qr_compact(randn(rng, T, m, m))[1]
S = Diagonal([0.9, 0.3, 0.1, 0.01])
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
A = U * S * Vᴴ
alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), TruncationKeepAbove(0.2, 0.0))
U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg)
@test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T)))
@test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2))
end
Loading