Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/src/dev_interface.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
```@meta
CurrentModule = MatrixAlgebraKit
CollapsedDocStrings = true
```

# Developer Interface

MatrixAlgebraKit.jl provides a developer interface for specifying custom algorithm backends and selecting default algorithms.

```@docs; canonical=false
MatrixAlgebraKit.default_algorithm
MatrixAlgebraKit.select_algorithm
```
3 changes: 3 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

VERSION >= v"1.11.0-DEV.469" &&
eval(Meta.parse("public default_algorithm, select_algorithm"))

include("common/defaults.jl")
include("common/initialization.jl")
include("common/pullbacks.jl")
Expand Down
60 changes: 51 additions & 9 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,60 @@
end

@doc """
select_algorithm(f, A; kwargs...)
MatrixAlgebraKit.select_algorithm(f, A, alg=nothing; kwargs...)
Given some keyword arguments and an input `A`, decide on an algrithm to use for
implementing the function `f` on inputs of type `A`.
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
If `alg` is `nothing` (the default value), an algorithm will be selected automatically
with [`MatrixAlgebraKit.default_algorithm`](@ref) and the keyword arguments will be passed
to the algorithm constructor.
If `alg` is a `NamedTuple`, an algorithm will be selected automatically
with [`default_algorithm`](@ref) and `alg` will be passed to the algorithm
as keyword arguments. In that case, keyword arguments can't be passed
to `MatrixAlgebraKit.select_algorithm`
If `alg` is an `AbstractAlgorithm`, it will be returned as-is. In that case, keyword arguments
can't be passed to `MatrixAlgebraKit.select_algorithm`.
"""
function select_algorithm end

function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
return _select_algorithm(f, A, alg; kwargs...)
end

function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
return default_algorithm(f, A; kwargs...)
end
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
return Algorithm{alg}(; kwargs...)
end
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
return Alg(; kwargs...)
end
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
return default_algorithm(f, A; alg...)
end
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
isempty(kwargs) ||
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
return alg
end
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
return select_algorithm(f, A; alg...)
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
return throw(ArgumentError("Unknown alg $alg"))

Check warning on line 99 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L98-L99

Added lines #L98 - L99 were not covered by tests
end

@doc """
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
Select the default algorithm for a given factorization function `f` and input `A`.
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
explicitly.
"""
function default_algorithm end

@doc """
copy_input(f, A)
Expand Down Expand Up @@ -138,9 +178,11 @@
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)

# fill in arguments
$f!(A; kwargs...) = $f!(A, select_algorithm($f!, A; kwargs...))
function $f!(A, out; kwargs...)
return $f!(A, out, select_algorithm($f!, A; kwargs...))
function $f!(A; alg=nothing, kwargs...)
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, out; alg=nothing, kwargs...)
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
end
function $f!(A, alg::AbstractAlgorithm)
return $f!(A, initialize_output($f!, A, alg), alg)
Expand Down
32 changes: 16 additions & 16 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,22 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
end
if kind == :qr
alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr)
alg_qr′ = select_algorithm(qr_compact!, A, alg_qr)
return qr_compact!(A, VC, alg_qr′)
elseif kind == :polar
size(A, 1) >= size(A, 2) ||
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
alg_polar′ = _select_algorithm(left_polar!, A, alg_polar)
alg_polar′ = select_algorithm(left_polar!, A, alg_polar)
return left_polar!(A, VC, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
return U, lmul!(S, Vᴴ)
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
Expand All @@ -122,22 +122,22 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
end
if kind == :lq
alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq)
alg_lq′ = select_algorithm(lq_compact!, A, alg_lq)
return lq_compact!(A, CVᴴ, alg_lq′)
elseif kind == :polar
size(A, 2) >= size(A, 1) ||
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
alg_polar′ = _select_algorithm(right_polar!, A, alg_polar)
alg_polar′ = select_algorithm(right_polar!, A, alg_polar)
return right_polar!(A, CVᴴ, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
return rmul!(U, S), Vᴴ
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
Expand Down Expand Up @@ -167,15 +167,15 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing,
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
end
if kind == :qr
alg_qr′ = _select_algorithm(qr_null!, A, alg_qr)
alg_qr′ = select_algorithm(qr_null!, A, alg_qr)
return qr_null!(A, N, alg_qr′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
U, _, _ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(N, view(U, 1:m, (n + 1):m))
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
U, S, _ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
Expand All @@ -194,15 +194,15 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
end
if kind == :lq
alg_lq′ = _select_algorithm(lq_null!, A, alg_lq)
alg_lq′ = select_algorithm(lq_null!, A, alg_lq)
return lq_null!(A, Nᴴ, alg_lq′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
_, _, Vᴴ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
elseif kind == :svd
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
_, S, Vᴴ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
Expand Down
17 changes: 15 additions & 2 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
"""
struct NoTruncation <: TruncationStrategy end

function select_truncation(trunc)
if isnothing(trunc)
return NoTruncation()
elseif trunc isa NamedTuple
return TruncationStrategy(; trunc...)
elseif trunc isa TruncationStrategy
return trunc
else
return throw(ArgumentError("Unknown truncation strategy: $trunc"))

Check warning on line 43 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L43

Added line #L43 was not covered by tests
end
end

# TODO: how do we deal with sorting/filters that treat zeros differently
# since these are implicitly discarded by selecting compact/full

Expand Down Expand Up @@ -98,8 +110,9 @@
TruncationStrategy
components::T
end
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) =
TruncationIntersection((trunc, truncs...))
function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
return TruncationIntersection((trunc, truncs...))

Check warning on line 114 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L113-L114

Added lines #L113 - L114 were not covered by tests
end

function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationIntersection((trunc1, trunc2))
Expand Down
29 changes: 9 additions & 20 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,32 +90,21 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
for f in (:eig_full, :eig_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eig_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eig_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(eig_trunc), A; kwargs...)
return select_algorithm(eig_trunc!, A; kwargs...)
function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
return select_algorithm(eig_trunc!, A, alg; kwargs...)
end
function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_eig, alg_trunc)
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end

# Default to LAPACK
Expand Down
29 changes: 9 additions & 20 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,21 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)
for f in (:eigh_full, :eigh_vals)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_eigh_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
end
end

function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
return select_algorithm(eigh_trunc!, A; kwargs...)
function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
end
function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...)
alg_trunc = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
isnothing(trunc) ? NoTruncation() :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return TruncatedAlgorithm(alg_eigh, alg_trunc)
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
end

# Default to LAPACK
Expand Down
15 changes: 4 additions & 11 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
for f in (:lq_full, :lq_compact, :lq_null)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_lq_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_lq_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,11 @@ end
for f in (:left_polar, :right_polar)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_polar_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_polar_algorithm(A; kwargs...)
end
end
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
for f in (:qr_full, :qr_compact, :qr_null)
f! = Symbol(f, :!)
@eval begin
function select_algorithm(::typeof($f), A; kwargs...)
return select_algorithm($f!, A; kwargs...)
function default_algorithm(::typeof($f), A; kwargs...)
return default_algorithm($f!, A; kwargs...)
end
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
if alg isa AbstractAlgorithm
return alg
elseif alg isa Symbol
return Algorithm{alg}(; kwargs...)
else
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
return default_qr_algorithm(A; kwargs...)
end
function default_algorithm(::typeof($f!), A; kwargs...)
return default_qr_algorithm(A; kwargs...)
end
end
end
Expand Down
Loading
Loading