Skip to content

Commit c46119e

Browse files
mtfishmanJutho
andauthored
Refactor algorithm selection logic (#23)
Co-authored-by: Jutho <[email protected]>
1 parent 1b44845 commit c46119e

File tree

16 files changed

+226
-140
lines changed

16 files changed

+226
-140
lines changed

docs/src/dev_interface.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
```@meta
2+
CurrentModule = MatrixAlgebraKit
3+
CollapsedDocStrings = true
4+
```
5+
6+
# Developer Interface
7+
8+
MatrixAlgebraKit.jl provides a developer interface for specifying custom algorithm backends and selecting default algorithms.
9+
10+
```@docs; canonical=false
11+
MatrixAlgebraKit.default_algorithm
12+
MatrixAlgebraKit.select_algorithm
13+
```

src/MatrixAlgebraKit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3030
LAPACK_DivideAndConquer, LAPACK_Jacobi
3131
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3232

33+
VERSION >= v"1.11.0-DEV.469" &&
34+
eval(Expr(:public, :default_algorithm, :select_algorithm))
35+
3336
include("common/defaults.jl")
3437
include("common/initialization.jl")
3538
include("common/pullbacks.jl")

src/algorithms.jl

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,64 @@ function _show_alg(io::IO, alg::Algorithm)
5454
end
5555

5656
@doc """
57-
select_algorithm(f, A; kwargs...)
57+
MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm)
58+
MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...)
59+
MatrixAlgebraKit.select_algorithm(f, A, alg::Type; kwargs...)
60+
MatrixAlgebraKit.select_algorithm(f, A; kwargs...)
61+
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))
5862
59-
Given some keyword arguments and an input `A`, decide on an algrithm to use for
60-
implementing the function `f` on inputs of type `A`.
63+
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
64+
65+
If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.
66+
67+
If `alg` is a `Symbol` or a `Type` of algorithm, the return value is obtained
68+
by calling the corresponding algorithm constructor;
69+
keyword arguments in `kwargs` are passed along to this constructor.
70+
71+
If `alg` is not specified (or `nothing`), an algorithm will be selected
72+
automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and
73+
the keyword arguments in `kwargs` will be passed to the algorithm constructor.
74+
Finally, the same behavior is obtained when the keyword arguments are
75+
passed as the third positional argument in the form of a `NamedTuple`.
6176
"""
6277
function select_algorithm end
6378

64-
function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
79+
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
80+
return _select_algorithm(f, A, alg; kwargs...)
81+
end
82+
83+
function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
84+
return default_algorithm(f, A; kwargs...)
85+
end
86+
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
87+
return Algorithm{alg}(; kwargs...)
88+
end
89+
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
90+
return Alg(; kwargs...)
91+
end
92+
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
93+
isempty(kwargs) ||
94+
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
95+
return default_algorithm(f, A; alg...)
96+
end
97+
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
98+
isempty(kwargs) ||
99+
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
65100
return alg
66101
end
67-
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
68-
return select_algorithm(f, A; alg...)
102+
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
103+
return throw(ArgumentError("Unknown alg $alg"))
69104
end
70105

106+
@doc """
107+
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
108+
109+
Select the default algorithm for a given factorization function `f` and input `A`.
110+
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
111+
explicitly.
112+
"""
113+
function default_algorithm end
114+
71115
@doc """
72116
copy_input(f, A)
73117
@@ -138,9 +182,11 @@ macro functiondef(f)
138182
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
139183

140184
# fill in arguments
141-
$f!(A; kwargs...) = $f!(A, select_algorithm($f!, A; kwargs...))
142-
function $f!(A, out; kwargs...)
143-
return $f!(A, out, select_algorithm($f!, A; kwargs...))
185+
function $f!(A; alg=nothing, kwargs...)
186+
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
187+
end
188+
function $f!(A, out; alg=nothing, kwargs...)
189+
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
144190
end
145191
function $f!(A, alg::AbstractAlgorithm)
146192
return $f!(A, initialize_output($f!, A, alg), alg)

src/implementations/orthnull.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,22 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
8989
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
9090
end
9191
if kind == :qr
92-
alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr)
92+
alg_qr′ = select_algorithm(qr_compact!, A, alg_qr)
9393
return qr_compact!(A, VC, alg_qr′)
9494
elseif kind == :polar
9595
size(A, 1) >= size(A, 2) ||
9696
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
97-
alg_polar′ = _select_algorithm(left_polar!, A, alg_polar)
97+
alg_polar′ = select_algorithm(left_polar!, A, alg_polar)
9898
return left_polar!(A, VC, alg_polar′)
9999
elseif kind == :svd && isnothing(trunc)
100-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
100+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
101101
V, C = VC
102102
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
103103
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
104104
return U, lmul!(S, Vᴴ)
105105
elseif kind == :svd
106-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
107-
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
106+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
107+
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
108108
V, C = VC
109109
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
110110
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
@@ -122,22 +122,22 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
122122
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
123123
end
124124
if kind == :lq
125-
alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq)
125+
alg_lq′ = select_algorithm(lq_compact!, A, alg_lq)
126126
return lq_compact!(A, CVᴴ, alg_lq′)
127127
elseif kind == :polar
128128
size(A, 2) >= size(A, 1) ||
129129
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
130-
alg_polar′ = _select_algorithm(right_polar!, A, alg_polar)
130+
alg_polar′ = select_algorithm(right_polar!, A, alg_polar)
131131
return right_polar!(A, CVᴴ, alg_polar′)
132132
elseif kind == :svd && isnothing(trunc)
133-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
133+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
134134
C, Vᴴ = CVᴴ
135135
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
136136
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
137137
return rmul!(U, S), Vᴴ
138138
elseif kind == :svd
139-
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
140-
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
139+
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
140+
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
141141
C, Vᴴ = CVᴴ
142142
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
143143
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
@@ -167,15 +167,15 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing,
167167
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
168168
end
169169
if kind == :qr
170-
alg_qr′ = _select_algorithm(qr_null!, A, alg_qr)
170+
alg_qr′ = select_algorithm(qr_null!, A, alg_qr)
171171
return qr_null!(A, N, alg_qr′)
172172
elseif kind == :svd && isnothing(trunc)
173-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
173+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
174174
U, _, _ = svd_full!(A, alg_svd′)
175175
(m, n) = size(A)
176176
return copy!(N, view(U, 1:m, (n + 1):m))
177177
elseif kind == :svd
178-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
178+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
179179
U, S, _ = svd_full!(A, alg_svd′)
180180
trunc′ = trunc isa TruncationStrategy ? trunc :
181181
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
@@ -194,15 +194,15 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
194194
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
195195
end
196196
if kind == :lq
197-
alg_lq′ = _select_algorithm(lq_null!, A, alg_lq)
197+
alg_lq′ = select_algorithm(lq_null!, A, alg_lq)
198198
return lq_null!(A, Nᴴ, alg_lq′)
199199
elseif kind == :svd && isnothing(trunc)
200-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
200+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
201201
_, _, Vᴴ = svd_full!(A, alg_svd′)
202202
(m, n) = size(A)
203203
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
204204
elseif kind == :svd
205-
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
205+
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
206206
_, S, Vᴴ = svd_full!(A, alg_svd′)
207207
trunc′ = trunc isa TruncationStrategy ? trunc :
208208
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :

src/implementations/truncation.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ Trivial truncation strategy that keeps all values, mostly for testing purposes.
3232
"""
3333
struct NoTruncation <: TruncationStrategy end
3434

35+
function select_truncation(trunc)
36+
if isnothing(trunc)
37+
return NoTruncation()
38+
elseif trunc isa NamedTuple
39+
return TruncationStrategy(; trunc...)
40+
elseif trunc isa TruncationStrategy
41+
return trunc
42+
else
43+
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
44+
end
45+
end
46+
3547
# TODO: how do we deal with sorting/filters that treat zeros differently
3648
# since these are implicitly discarded by selecting compact/full
3749

@@ -98,8 +110,9 @@ struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <:
98110
TruncationStrategy
99111
components::T
100112
end
101-
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) =
102-
TruncationIntersection((trunc, truncs...))
113+
function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
114+
return TruncationIntersection((trunc, truncs...))
115+
end
103116

104117
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
105118
return TruncationIntersection((trunc1, trunc2))

src/interface/eig.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,32 +90,21 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
9090
for f in (:eig_full, :eig_vals)
9191
f! = Symbol(f, :!)
9292
@eval begin
93-
function select_algorithm(::typeof($f), A; kwargs...)
94-
return select_algorithm($f!, A; kwargs...)
93+
function default_algorithm(::typeof($f), A; kwargs...)
94+
return default_algorithm($f!, A; kwargs...)
9595
end
96-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
97-
if alg isa AbstractAlgorithm
98-
return alg
99-
elseif alg isa Symbol
100-
return Algorithm{alg}(; kwargs...)
101-
else
102-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
103-
return default_eig_algorithm(A; kwargs...)
104-
end
96+
function default_algorithm(::typeof($f!), A; kwargs...)
97+
return default_eig_algorithm(A; kwargs...)
10598
end
10699
end
107100
end
108101

109-
function select_algorithm(::typeof(eig_trunc), A; kwargs...)
110-
return select_algorithm(eig_trunc!, A; kwargs...)
102+
function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
103+
return select_algorithm(eig_trunc!, A, alg; kwargs...)
111104
end
112-
function select_algorithm(::typeof(eig_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
113-
alg_eig = select_algorithm(eig_full!, A; alg, kwargs...)
114-
alg_trunc = trunc isa TruncationStrategy ? trunc :
115-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
116-
isnothing(trunc) ? NoTruncation() :
117-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
118-
return TruncatedAlgorithm(alg_eig, alg_trunc)
105+
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
106+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
107+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
119108
end
120109

121110
# Default to LAPACK

src/interface/eigh.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,32 +89,21 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)
8989
for f in (:eigh_full, :eigh_vals)
9090
f! = Symbol(f, :!)
9191
@eval begin
92-
function select_algorithm(::typeof($f), A; kwargs...)
93-
return select_algorithm($f!, A; kwargs...)
92+
function default_algorithm(::typeof($f), A; kwargs...)
93+
return default_algorithm($f!, A; kwargs...)
9494
end
95-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
96-
if alg isa AbstractAlgorithm
97-
return alg
98-
elseif alg isa Symbol
99-
return Algorithm{alg}(; kwargs...)
100-
else
101-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
102-
return default_eigh_algorithm(A; kwargs...)
103-
end
95+
function default_algorithm(::typeof($f!), A; kwargs...)
96+
return default_eigh_algorithm(A; kwargs...)
10497
end
10598
end
10699
end
107100

108-
function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
109-
return select_algorithm(eigh_trunc!, A; kwargs...)
101+
function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
102+
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
110103
end
111-
function select_algorithm(::typeof(eigh_trunc!), A; alg=nothing, trunc=nothing, kwargs...)
112-
alg_eigh = select_algorithm(eigh_full!, A; alg, kwargs...)
113-
alg_trunc = trunc isa TruncationStrategy ? trunc :
114-
trunc isa NamedTuple ? TruncationStrategy(; trunc...) :
115-
isnothing(trunc) ? NoTruncation() :
116-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
117-
return TruncatedAlgorithm(alg_eigh, alg_trunc)
104+
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
105+
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
106+
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
118107
end
119108

120109
# Default to LAPACK

src/interface/lq.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
7171
for f in (:lq_full, :lq_compact, :lq_null)
7272
f! = Symbol(f, :!)
7373
@eval begin
74-
function select_algorithm(::typeof($f), A; kwargs...)
75-
return select_algorithm($f!, A; kwargs...)
74+
function default_algorithm(::typeof($f), A; kwargs...)
75+
return default_algorithm($f!, A; kwargs...)
7676
end
77-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
78-
if alg isa AbstractAlgorithm
79-
return alg
80-
elseif alg isa Symbol
81-
return Algorithm{alg}(; kwargs...)
82-
else
83-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
84-
return default_lq_algorithm(A; kwargs...)
85-
end
77+
function default_algorithm(::typeof($f!), A; kwargs...)
78+
return default_lq_algorithm(A; kwargs...)
8679
end
8780
end
8881
end

src/interface/polar.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,11 @@ end
6363
for f in (:left_polar, :right_polar)
6464
f! = Symbol(f, :!)
6565
@eval begin
66-
function select_algorithm(::typeof($f), A; kwargs...)
67-
return select_algorithm($f!, A; kwargs...)
66+
function default_algorithm(::typeof($f), A; kwargs...)
67+
return default_algorithm($f!, A; kwargs...)
6868
end
69-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
70-
if alg isa AbstractAlgorithm
71-
return alg
72-
elseif alg isa Symbol
73-
return Algorithm{alg}(; kwargs...)
74-
else
75-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
76-
return default_polar_algorithm(A; kwargs...)
77-
end
69+
function default_algorithm(::typeof($f!), A; kwargs...)
70+
return default_polar_algorithm(A; kwargs...)
7871
end
7972
end
8073
end

src/interface/qr.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
7171
for f in (:qr_full, :qr_compact, :qr_null)
7272
f! = Symbol(f, :!)
7373
@eval begin
74-
function select_algorithm(::typeof($f), A; kwargs...)
75-
return select_algorithm($f!, A; kwargs...)
74+
function default_algorithm(::typeof($f), A; kwargs...)
75+
return default_algorithm($f!, A; kwargs...)
7676
end
77-
function select_algorithm(::typeof($f!), A; alg=nothing, kwargs...)
78-
if alg isa AbstractAlgorithm
79-
return alg
80-
elseif alg isa Symbol
81-
return Algorithm{alg}(; kwargs...)
82-
else
83-
isnothing(alg) || throw(ArgumentError("Unknown alg $alg"))
84-
return default_qr_algorithm(A; kwargs...)
85-
end
77+
function default_algorithm(::typeof($f!), A; kwargs...)
78+
return default_qr_algorithm(A; kwargs...)
8679
end
8780
end
8881
end

0 commit comments

Comments
 (0)