diff --git a/docs/src/dev_interface.md b/docs/src/dev_interface.md index e69de29b..4482a5c6 100644 --- a/docs/src/dev_interface.md +++ b/docs/src/dev_interface.md @@ -0,0 +1,13 @@ +```@meta +CurrentModule = MatrixAlgebraKit +CollapsedDocStrings = true +``` + +# Developer Interface + +MatrixAlgebraKit.jl provides a developer interface for specifying custom algorithm backends and selecting default algorithms. + +```@docs; canonical=false +MatrixAlgebraKit.default_algorithm +MatrixAlgebraKit.select_algorithm +``` diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index c6820233..a9f48393 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -30,6 +30,9 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_DivideAndConquer, LAPACK_Jacobi export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered +VERSION >= v"1.11.0-DEV.469" && + eval(Expr(:public, :default_algorithm, :select_algorithm)) + include("common/defaults.jl") include("common/initialization.jl") include("common/pullbacks.jl") diff --git a/src/algorithms.jl b/src/algorithms.jl index 34a3f3d1..4a6bfe58 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -54,20 +54,64 @@ function _show_alg(io::IO, alg::Algorithm) end @doc """ - select_algorithm(f, A; kwargs...) + MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm) + MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...) + MatrixAlgebraKit.select_algorithm(f, A, alg::Type; kwargs...) + MatrixAlgebraKit.select_algorithm(f, A; kwargs...) + MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...)) -Given some keyword arguments and an input `A`, decide on an algrithm to use for -implementing the function `f` on inputs of type `A`. +Decide on an algorithm to use for implementing the function `f` on inputs of type `A`. + +If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is. + +If `alg` is a `Symbol` or a `Type` of algorithm, the return value is obtained +by calling the corresponding algorithm constructor; +keyword arguments in `kwargs` are passed along to this constructor. + +If `alg` is not specified (or `nothing`), an algorithm will be selected +automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and +the keyword arguments in `kwargs` will be passed to the algorithm constructor. +Finally, the same behavior is obtained when the keyword arguments are +passed as the third positional argument in the form of a `NamedTuple`. """ function select_algorithm end -function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm) +function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} + return _select_algorithm(f, A, alg; kwargs...) +end + +function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F} + return default_algorithm(f, A; kwargs...) +end +function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F} + return Algorithm{alg}(; kwargs...) +end +function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg} + return Alg(; kwargs...) +end +function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F} + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) + return default_algorithm(f, A; alg...) +end +function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F} + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified.")) return alg end -function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple) - return select_algorithm(f, A; alg...) +function _select_algorithm(f::F, A, alg; kwargs...) where {F} + return throw(ArgumentError("Unknown alg $alg")) end +@doc """ + MatrixAlgebraKit.default_algorithm(f, A; kwargs...) + +Select the default algorithm for a given factorization function `f` and input `A`. +In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified +explicitly. +""" +function default_algorithm end + @doc """ copy_input(f, A) @@ -138,9 +182,11 @@ macro functiondef(f) $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) # fill in arguments - $f!(A; kwargs...) = $f!(A, select_algorithm($f!, A; kwargs...)) - function $f!(A, out; kwargs...) - return $f!(A, out, select_algorithm($f!, A; kwargs...)) + function $f!(A; alg=nothing, kwargs...) + return $f!(A, select_algorithm($f!, A, alg; kwargs...)) + end + function $f!(A, out; alg=nothing, kwargs...) + return $f!(A, out, select_algorithm($f!, A, alg; kwargs...)) end function $f!(A, alg::AbstractAlgorithm) return $f!(A, initialize_output($f!, A, alg), alg) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 2680e7cc..47626e12 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -89,22 +89,22 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr) + alg_qr′ = select_algorithm(qr_compact!, A, alg_qr) return qr_compact!(A, VC, alg_qr′) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg_polar′ = _select_algorithm(left_polar!, A, alg_polar) + alg_polar′ = select_algorithm(left_polar!, A, alg_polar) return left_polar!(A, VC, alg_polar′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) + alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′) return U, lmul!(S, Vᴴ) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) - alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) + alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) + alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc) @@ -122,22 +122,22 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq) + alg_lq′ = select_algorithm(lq_compact!, A, alg_lq) return lq_compact!(A, CVᴴ, alg_lq′) elseif kind == :polar size(A, 2) >= size(A, 1) || throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`")) - alg_polar′ = _select_algorithm(right_polar!, A, alg_polar) + alg_polar′ = select_algorithm(right_polar!, A, alg_polar) return right_polar!(A, CVᴴ, alg_polar′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) + alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′) return rmul!(U, S), Vᴴ elseif kind == :svd - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) - alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) + alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) + alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc) @@ -167,15 +167,15 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing, throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end if kind == :qr - alg_qr′ = _select_algorithm(qr_null!, A, alg_qr) + alg_qr′ = select_algorithm(qr_null!, A, alg_qr) return qr_null!(A, N, alg_qr′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = select_algorithm(svd_full!, A, alg_svd) U, _, _ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(N, view(U, 1:m, (n + 1):m)) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = select_algorithm(svd_full!, A, alg_svd) U, S, _ = svd_full!(A, alg_svd′) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : @@ -194,15 +194,15 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end if kind == :lq - alg_lq′ = _select_algorithm(lq_null!, A, alg_lq) + alg_lq′ = select_algorithm(lq_null!, A, alg_lq) return lq_null!(A, Nᴴ, alg_lq′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = select_algorithm(svd_full!, A, alg_svd) _, _, Vᴴ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = select_algorithm(svd_full!, A, alg_svd) _, S, Vᴴ = svd_full!(A, alg_svd′) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 95a5c002..1898a010 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -32,6 +32,18 @@ Trivial truncation strategy that keeps all values, mostly for testing purposes. """ struct NoTruncation <: TruncationStrategy end +function select_truncation(trunc) + if isnothing(trunc) + return NoTruncation() + elseif trunc isa NamedTuple + return TruncationStrategy(; trunc...) + elseif trunc isa TruncationStrategy + return trunc + else + return throw(ArgumentError("Unknown truncation strategy: $trunc")) + end +end + # TODO: how do we deal with sorting/filters that treat zeros differently # since these are implicitly discarded by selecting compact/full @@ -98,8 +110,9 @@ struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy components::T end -TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) = - TruncationIntersection((trunc, truncs...)) +function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) + return TruncationIntersection((trunc, truncs...)) +end function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) return TruncationIntersection((trunc1, trunc2)) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index fae199a0..9071a657 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -90,32 +90,21 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). for f in (:eig_full, :eig_vals) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_eig_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_eig_algorithm(A; kwargs...) end end end -function select_algorithm(::typeof(eig_trunc), A; kwargs...) - return select_algorithm(eig_trunc!, A; kwargs...) +function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...) + return select_algorithm(eig_trunc!, A, alg; kwargs...) end -function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...) - alg_eig = select_algorithm(eig_full!, A; alg, kwargs...) - alg_trunc = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? TruncationStrategy(; trunc...) : - isnothing(trunc) ? NoTruncation() : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return TruncatedAlgorithm(alg_eig, alg_trunc) +function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...) + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) end # Default to LAPACK diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 42e883db..b092795c 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -89,32 +89,21 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc) for f in (:eigh_full, :eigh_vals) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_eigh_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_eigh_algorithm(A; kwargs...) end end end -function select_algorithm(::typeof(eigh_trunc), A; kwargs...) - return select_algorithm(eigh_trunc!, A; kwargs...) +function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...) + return select_algorithm(eigh_trunc!, A, alg; kwargs...) end -function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...) - alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...) - alg_trunc = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? TruncationStrategy(; trunc...) : - isnothing(trunc) ? NoTruncation() : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return TruncatedAlgorithm(alg_eigh, alg_trunc) +function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...) + alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) end # Default to LAPACK diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 82ef86e1..e98223f1 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -71,18 +71,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). for f in (:lq_full, :lq_compact, :lq_null) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_lq_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_lq_algorithm(A; kwargs...) end end end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index c5d47e46..b209a327 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -63,18 +63,11 @@ end for f in (:left_polar, :right_polar) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_polar_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_polar_algorithm(A; kwargs...) end end end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index b7af9554..cbded32d 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -71,18 +71,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). for f in (:qr_full, :qr_compact, :qr_null) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_qr_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_qr_algorithm(A; kwargs...) end end end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index 52ed3e96..c49e6a2b 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -54,18 +54,11 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). for f in (:schur_full, :schur_vals) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_eig_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_eig_algorithm(A; kwargs...) end end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 09fb79fd..1c5d7e3a 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -93,32 +93,21 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an for f in (:svd_full, :svd_compact, :svd_vals) f! = Symbol(f, :!) @eval begin - function select_algorithm(::typeof($f), A; kwargs...) - return select_algorithm($f!, A; kwargs...) + function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) end - function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...) - if alg isa AbstractAlgorithm - return alg - elseif alg isa Symbol - return Algorithm{alg}(; kwargs...) - else - isnothing(alg) || throw(ArgumentError("Unknown alg $alg")) - return default_svd_algorithm(A; kwargs...) - end + function default_algorithm(::typeof($f!), A; kwargs...) + return default_svd_algorithm(A; kwargs...) end end end -function select_algorithm(::typeof(svd_trunc), A; kwargs...) - return select_algorithm(svd_trunc!, A; kwargs...) +function select_algorithm(::typeof(svd_trunc), A, alg; kwargs...) + return select_algorithm(svd_trunc!, A, alg; kwargs...) end -function select_algorithm(::typeof(svd_trunc!), A; alg=nothing, trunc=nothing, kwargs...) - alg_svd = select_algorithm(svd_compact!, A; alg, kwargs...) - alg_trunc = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? TruncationStrategy(; trunc...) : - isnothing(trunc) ? NoTruncation() : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return TruncatedAlgorithm(alg_svd, alg_trunc) +function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...) + alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) end # Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` diff --git a/test/algorithms.jl b/test/algorithms.jl new file mode 100644 index 00000000..49524a89 --- /dev/null +++ b/test/algorithms.jl @@ -0,0 +1,60 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm, + default_algorithm, select_algorithm + +@testset "default_algorithm" begin + A = randn(3, 3) + for f in (svd_compact!, svd_compact, svd_full!, svd_full) + @test @constinferred(default_algorithm(f, A)) === LAPACK_DivideAndConquer() + end + for f in (eig_full!, eig_full, eig_vals!, eig_vals) + @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() + end + for f in (eigh_full!, eigh_full, eigh_vals!, eigh_vals) + @test @constinferred(default_algorithm(f, A)) === + LAPACK_MultipleRelativelyRobustRepresentations() + end + for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null) + @test @constinferred(default_algorithm(f, A)) == LAPACK_HouseholderLQ() + end + for f in (left_polar!, left_polar, right_polar!, right_polar) + @test @constinferred(default_algorithm(f, A)) == + PolarViaSVD(LAPACK_DivideAndConquer()) + end + for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null) + @test @constinferred(default_algorithm(f, A)) == LAPACK_HouseholderQR() + end + for f in (schur_full!, schur_full, schur_vals!, schur_vals) + @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() + end + + @test @constinferred(default_algorithm(qr_compact!, A; blocksize=2)) === + LAPACK_HouseholderQR(; blocksize=2) +end + +@testset "select_algorithm" begin + A = randn(3, 3) + for f in (svd_trunc!, svd_trunc) + @test @constinferred(select_algorithm(f, A)) === + TruncatedAlgorithm(LAPACK_DivideAndConquer(), NoTruncation()) + end + for f in (eig_trunc!, eig_trunc) + @test @constinferred(select_algorithm(f, A)) === + TruncatedAlgorithm(LAPACK_Expert(), NoTruncation()) + end + for f in (eigh_trunc!, eigh_trunc) + @test @constinferred(select_algorithm(f, A)) === + TruncatedAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(), + NoTruncation()) + end + + @test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === + LAPACK_DivideAndConquer() + for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration()) + @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === + LAPACK_QRIteration() + end +end diff --git a/test/eig.jl b/test/eig.jl index d0dc13e3..cd73d94e 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -8,21 +8,23 @@ using MatrixAlgebraKit: diagview @testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 - for alg in (LAPACK_Simple(), LAPACK_Expert()) + for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple) A = randn(rng, T, m, m) Tc = complex(T) - D, V = @constinferred eig_full(A; alg) + D, V = @constinferred eig_full(A; alg=($alg)) @test eltype(D) == eltype(V) == Tc @test A * V ≈ V * D + alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) + Ac = similar(A) - D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg) + D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) @test D2 === D @test V2 === V @test A * V ≈ V * D - Dc = @constinferred eig_vals(A, alg) + Dc = @constinferred eig_vals(A, alg′) @test eltype(Dc) == Tc @test D ≈ Diagonal(Dc) end diff --git a/test/runtests.jl b/test/runtests.jl index 7ad7c216..5f0d2991 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,8 @@ using SafeTestsets +@safetestset "Algorithms" begin + include("algorithms.jl") +end @safetestset "Truncate" begin include("truncate.jl") end diff --git a/test/svd.jl b/test/svd.jl index e40a69e9..eb6a7805 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -11,17 +11,23 @@ using MatrixAlgebraKit: TruncationKeepAbove, diagview @testset "size ($m, $n)" for n in (37, m, 63) k = min(m, n) if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), + LAPACK_DivideAndConquer, :LAPACK_DivideAndConquer) else algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), - LAPACK_Jacobi()) + LAPACK_Jacobi(), LAPACK_DivideAndConquer, :LAPACK_DivideAndConquer) end @testset "algorithm $alg" for alg in algs n > m && alg isa LAPACK_Jacobi && continue # not supported minmn = min(m, n) A = randn(rng, T, m, n) - U, S, Vᴴ = svd_compact(A; alg) + if VERSION < v"1.11" + # This is type unstable on older versions of Julia. + U, S, Vᴴ = svd_compact(A; alg) + else + U, S, Vᴴ = @constinferred svd_compact(A; alg=($alg)) + end @test U isa Matrix{T} && size(U) == (m, minmn) @test S isa Diagonal{real(T)} && size(S) == (minmn, minmn) @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) @@ -32,7 +38,8 @@ using MatrixAlgebraKit: TruncationKeepAbove, diagview Ac = similar(A) Sc = similar(A, real(T), min(m, n)) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) + alg′ = @constinferred MatrixAlgebraKit.select_algorithm(svd_compact!, A, $alg) + U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg′) @test U2 === U @test S2 === S @test V2ᴴ === Vᴴ @@ -41,7 +48,7 @@ using MatrixAlgebraKit: TruncationKeepAbove, diagview @test Vᴴ * Vᴴ' ≈ I @test isposdef(S) - Sd = svd_vals(A, alg) + Sd = @constinferred svd_vals(A, alg′) @test S ≈ Diagonal(Sd) end end