diff --git a/src/algorithms.jl b/src/algorithms.jl index 6f9a4d49..af183103 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -77,9 +77,6 @@ passed as the third positional argument in the form of a `NamedTuple`. """ select_algorithm function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} - return select_algorithm(f, typeof(A), alg; kwargs...) -end -function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg} if isnothing(alg) return default_algorithm(f, A; kwargs...) elseif alg isa Symbol @@ -193,10 +190,24 @@ macro functiondef(f) end # define fallbacks for algorithm selection - @inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg; - kwargs...) where {Alg,A} + @inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg} return select_algorithm($f!, A, alg; kwargs...) end + # define default algorithm fallbacks for out-of-place functions + # in terms of the corresponding in-place function + @inline function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) + end + # define default algorithm fallbacks for out-of-place functions + # in terms of the corresponding in-place function for types, + # in principle this is covered by the definition above but + # it is necessary to avoid ambiguity errors with the generic definitions: + # ```julia + # default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...) + # function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T} + # throw(MethodError(default_algorithm, (f, T))) + # end + # ``` @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_algorithm($f!, A; kwargs...) end diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 77aa0672..fd75193f 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -99,8 +99,13 @@ for f in (:eig_full!, :eig_vals!) end end -function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:YALAPACK.BlasMat} - alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) +function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 3ed38789..a650ca44 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -100,8 +100,13 @@ for f in (:eigh_full!, :eigh_vals!) end end -function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:YALAPACK.BlasMat} - alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) +function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end end diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 9de85bae..6f1ed12f 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -77,9 +77,7 @@ function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:lq_full!, :lq_compact!, :lq_null!) - @eval begin - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_lq_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_lq_algorithm(A; kwargs...) end end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index 111a6e3a..87346ff2 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -61,10 +61,7 @@ end # Algorithm selection # ------------------- default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...) -function default_polar_algorithm(T::Type; kwargs...) - throw(MethodError(default_polar_algorithm, (T,))) -end -function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} +function default_polar_algorithm(::Type{T}; kwargs...) where {T} return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...)) end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index a1c12b7a..62d87080 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -77,10 +77,7 @@ function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:qr_full!, :qr_compact!, :qr_null!) - @eval begin - function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return default_qr_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_qr_algorithm(A; kwargs...) end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index e6f9021a..fd4eb5a5 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -104,8 +104,13 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end end -function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:YALAPACK.BlasMat} - alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) +function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) + end end diff --git a/test/algorithms.jl b/test/algorithms.jl index 49524a89..9f9c4542 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm, - default_algorithm, select_algorithm + TruncationKeepBelow, default_algorithm, select_algorithm @testset "default_algorithm" begin A = randn(3, 3) @@ -50,6 +50,12 @@ end NoTruncation()) end + alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0)) + for f in (eig_trunc!, eigh_trunc!, svd_trunc!) + @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg + @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2)) + end + @test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer() @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_DivideAndConquer() diff --git a/test/eig.jl b/test/eig.jl index cd73d94e..cdaec9dc 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: Diagonal -using MatrixAlgebraKit: diagview +using MatrixAlgebraKit: TruncatedAlgorithm, diagview @testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -57,3 +57,17 @@ end @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 end end + +@testset "eig_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + V = randn(rng, T, m, m) + D = Diagonal([0.9, 0.3, 0.1, 0.01]) + A = V * D * inv(V) + alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) + D2, V2 = @constinferred eig_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2)) +end diff --git a/test/eigh.jl b/test/eigh.jl index 5a3c5a8a..6e785158 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: diagview +using MatrixAlgebraKit: TruncatedAlgorithm, diagview @testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -62,3 +62,19 @@ end @test V2 * (V2' * V1) ≈ V1 end end + +@testset "eigh_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, + ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + V = qr_compact(randn(rng, T, m, m))[1] + D = Diagonal([0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) + D2, V2 = @constinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) +end diff --git a/test/svd.jl b/test/svd.jl index eb6a7805..40de0897 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef -using MatrixAlgebraKit: TruncationKeepAbove, diagview +using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -152,3 +152,18 @@ end end end end + +@testset "svd_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal([0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), TruncationKeepAbove(0.2, 0.0)) + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2)) +end