Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/src/dev_interface.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
```@meta
CurrentModule = MatrixAlgebraKit
CollapsedDocStrings = true
```

# Developer Interface

MatrixAlgebraKit.jl provides a developer interface for specifying custom algorithm backends and selecting default algorithms.

```@docs; canonical=false
MatrixAlgebraKit.default_algorithm
MatrixAlgebraKit.select_algorithm
```
3 changes: 3 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

VERSION >= v"1.11.0-DEV.469" &&
eval(Expr(:public, :default_algorithm, :select_algorithm))

include("common/defaults.jl")
include("common/initialization.jl")
include("common/pullbacks.jl")
Expand Down
64 changes: 55 additions & 9 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,64 @@
end

@doc """
select_algorithm(f, A; kwargs...)
MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm)
MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...)
MatrixAlgebraKit.select_algorithm(f, A, alg::Type; kwargs...)
MatrixAlgebraKit.select_algorithm(f, A; kwargs...)
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))

Given some keyword arguments and an input `A`, decide on an algrithm to use for
implementing the function `f` on inputs of type `A`.
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.

If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.

If `alg` is a `Symbol` or a `Type` of algorithm, the return value is obtained
by calling the corresponding algorithm constructor;
keyword arguments in `kwargs` are passed along to this constructor.

If `alg` is not specified (or `nothing`), an algorithm will be selected
automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and
the keyword arguments in `kwargs` will be passed to the algorithm constructor.
Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
"""
function select_algorithm end

function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
return _select_algorithm(f, A, alg; kwargs...)
end

function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
return default_algorithm(f, A; kwargs...)
end
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
return Algorithm{alg}(; kwargs...)
end
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
return Alg(; kwargs...)
end
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return default_algorithm(f, A; alg...)
end
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
return alg
end
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
return select_algorithm(f, A; alg...)
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
return throw(ArgumentError("Unknown alg $alg"))

Check warning on line 103 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L102-L103

Added lines #L102 - L103 were not covered by tests
end

@doc """
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)

Select the default algorithm for a given factorization function `f` and input `A`.
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
explicitly.
"""
function default_algorithm end

@doc """
copy_input(f, A)

Expand Down Expand Up @@ -138,9 +182,11 @@
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

# fill in arguments
$f!(A; kwargs...) = $f!(A, select_algorithm($f!, A; kwargs...))
function $f!(A, out; kwargs...)
return $f!(A, out, select_algorithm($f!, A; kwargs...))
function $f!(A; alg=nothing, kwargs...)
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, out; alg=nothing, kwargs...)
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, alg::AbstractAlgorithm)
return $f!(A, initialize_output($f!, A, alg), alg)
Expand Down
32 changes: 16 additions & 16 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,22 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
end
if kind == :qr
alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr)
alg_qr′ = select_algorithm(qr_compact!, A, alg_qr)
return qr_compact!(A, VC, alg_qr′)
elseif kind == :polar
size(A, 1) >= size(A, 2) ||
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
alg_polar′ = _select_algorithm(left_polar!, A, alg_polar)
alg_polar′ = select_algorithm(left_polar!, A, alg_polar)
return left_polar!(A, VC, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
return U, lmul!(S, Vᴴ)
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
Expand All @@ -122,22 +122,22 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
end
if kind == :lq
alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq)
alg_lq′ = select_algorithm(lq_compact!, A, alg_lq)
return lq_compact!(A, CVᴴ, alg_lq′)
elseif kind == :polar
size(A, 2) >= size(A, 1) ||
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
alg_polar′ = _select_algorithm(right_polar!, A, alg_polar)
alg_polar′ = select_algorithm(right_polar!, A, alg_polar)
return right_polar!(A, CVᴴ, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
return rmul!(U, S), Vᴴ
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
Expand Down Expand Up @@ -167,15 +167,15 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing,
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
end
if kind == :qr
alg_qr′ = _select_algorithm(qr_null!, A, alg_qr)
alg_qr′ = select_algorithm(qr_null!, A, alg_qr)
return qr_null!(A, N, alg_qr′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
U, _, _ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(N, view(U, 1:m, (n + 1):m))
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
U, S, _ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
Expand All @@ -194,15 +194,15 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
end
if kind == :lq
alg_lq′ = _select_algorithm(lq_null!, A, alg_lq)
alg_lq′ = select_algorithm(lq_null!, A, alg_lq)
return lq_null!(A, Nᴴ, alg_lq′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
_, _, Vᴴ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
_, S, Vᴴ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
Expand Down
17 changes: 15 additions & 2 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
"""
struct NoTruncation <: TruncationStrategy end

function select_truncation(trunc)
if isnothing(trunc)
return NoTruncation()
elseif trunc isa NamedTuple
return TruncationStrategy(; trunc...)
elseif trunc isa TruncationStrategy
return trunc
else
return throw(ArgumentError("Unknown truncation strategy: $trunc"))

Check warning on line 43 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L43

Added line #L43 was not covered by tests
end
end

# TODO: how do we deal with sorting/filters that treat zeros differently
# since these are implicitly discarded by selecting compact/full

Expand Down Expand Up @@ -98,8 +110,9 @@
TruncationStrategy
components::T
end
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) =
TruncationIntersection((trunc, truncs...))
function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
return TruncationIntersection((trunc, truncs...))

Check warning on line 114 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L113-L114

Added lines #L113 - L114 were not covered by tests
end

function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1, trunc2))
Expand Down
29 changes: 9 additions & 20 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,32 +90,21 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
for f in (:eig_full, :eig_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eig_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(eig_trunc), A; kwargs...)
return select_algorithm(eig_trunc!, A; kwargs...)
function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
return select_algorithm(eig_trunc!, A, alg; kwargs...)
end
function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_eig, alg_trunc)
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end

# Default to LAPACK
Expand Down
29 changes: 9 additions & 20 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,21 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)
for f in (:eigh_full, :eigh_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eigh_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
return select_algorithm(eigh_trunc!, A; kwargs...)
function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
end
function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_eigh, alg_trunc)
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
end

# Default to LAPACK
Expand Down
15 changes: 4 additions & 11 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
for f in (:lq_full, :lq_compact, :lq_null)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_lq_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_lq_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,11 @@ end
for f in (:left_polar, :right_polar)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_polar_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_polar_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
for f in (:qr_full, :qr_compact, :qr_null)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_qr_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_qr_algorithm(A; kwargs...)
end
end
end
Expand Down
Loading
Loading