Skip to content

Commit 765a6a4

Browse files
authored
Refactor algorithm selection in type domain (#30)
1 parent 1e86aea commit 765a6a4

File tree

11 files changed

+139
-151
lines changed

11 files changed

+139
-151
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <[email protected]> and contributors"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -36,4 +36,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
3636
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3737

3838
[targets]
39-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"]
39+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore",
40+
"ChainRulesTestUtils", "StableRNGs", "Zygote"]

src/algorithms.jl

Lines changed: 60 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ end
6161
MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...))
6262
6363
Decide on an algorithm to use for implementing the function `f` on inputs of type `A`.
64+
This can be obtained both for values `A` or types `A`.
6465
6566
If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is.
6667
@@ -73,62 +74,62 @@ automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and
7374
the keyword arguments in `kwargs` will be passed to the algorithm constructor.
7475
Finally, the same behavior is obtained when the keyword arguments are
7576
passed as the third positional argument in the form of a `NamedTuple`.
76-
"""
77-
function select_algorithm end
77+
""" select_algorithm
7878

7979
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
80-
return _select_algorithm(f, A, alg; kwargs...)
80+
return select_algorithm(f, typeof(A), alg; kwargs...)
8181
end
82+
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
83+
if isnothing(alg)
84+
return default_algorithm(f, A; kwargs...)
85+
elseif alg isa Symbol
86+
return Algorithm{alg}(; kwargs...)
87+
elseif alg isa Type
88+
return alg(; kwargs...)
89+
elseif alg isa NamedTuple
90+
isempty(kwargs) ||
91+
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
92+
return default_algorithm(f, A; alg...)
93+
elseif alg isa AbstractAlgorithm
94+
isempty(kwargs) ||
95+
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
96+
return alg
97+
end
8298

83-
function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
84-
return default_algorithm(f, A; kwargs...)
85-
end
86-
function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
87-
return Algorithm{alg}(; kwargs...)
88-
end
89-
function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
90-
return Alg(; kwargs...)
91-
end
92-
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
93-
isempty(kwargs) ||
94-
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
95-
return default_algorithm(f, A; alg...)
96-
end
97-
function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F}
98-
isempty(kwargs) ||
99-
throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified."))
100-
return alg
101-
end
102-
function _select_algorithm(f::F, A, alg; kwargs...) where {F}
103-
return throw(ArgumentError("Unknown alg $alg"))
99+
throw(ArgumentError("Unknown alg $alg"))
104100
end
105101

102+
106103
@doc """
107104
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
105+
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}
108106
109107
Select the default algorithm for a given factorization function `f` and input `A`.
110108
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified
111109
explicitly.
112-
"""
113-
function default_algorithm end
110+
New types should prefer to register their default algorithms in the type domain.
111+
""" default_algorithm
112+
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
113+
# avoid infinite recursion:
114+
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
115+
throw(MethodError(default_algorithm, (f, T)))
116+
end
114117

115118
@doc """
116119
copy_input(f, A)
117120
118121
Preprocess the input `A` for a given function, such that it may be handled correctly later.
119122
This may include a copy whenever the implementation would destroy the original matrix,
120123
or a change of element type to something that is supported.
121-
"""
122-
function copy_input end
124+
""" copy_input
123125

124126
@doc """
125127
initialize_output(f, A, alg)
126128
127129
Whenever possible, allocate the destination for applying a given algorithm in-place.
128130
If this is not possible, for example when the output size is not known a priori or immutable,
129131
this function may return `nothing`.
130-
"""
131-
function initialize_output end
132+
""" initialize_output
132133

133134
# Utility macros
134135
# --------------
@@ -176,25 +177,35 @@ macro functiondef(f)
176177
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
177178
f! = Symbol(f, :!)
178179

179-
return esc(quote
180-
# out of place to inplace
181-
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
182-
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
183-
184-
# fill in arguments
185-
function $f!(A; alg=nothing, kwargs...)
186-
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
187-
end
188-
function $f!(A, out; alg=nothing, kwargs...)
189-
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
190-
end
191-
function $f!(A, alg::AbstractAlgorithm)
192-
return $f!(A, initialize_output($f!, A, alg), alg)
193-
end
194-
195-
# copy documentation to both functions
196-
Core.@__doc__ $f, $f!
197-
end)
180+
ex = quote
181+
# out of place to inplace
182+
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
183+
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
184+
185+
# fill in arguments
186+
function $f!(A; alg=nothing, kwargs...)
187+
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
188+
end
189+
function $f!(A, out; alg=nothing, kwargs...)
190+
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
191+
end
192+
function $f!(A, alg::AbstractAlgorithm)
193+
return $f!(A, initialize_output($f!, A, alg), alg)
194+
end
195+
196+
# define fallbacks for algorithm selection
197+
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
198+
kwargs...) where {Alg,A}
199+
return select_algorithm($f!, A, alg; kwargs...)
200+
end
201+
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
202+
return default_algorithm($f!, A; kwargs...)
203+
end
204+
205+
# copy documentation to both functions
206+
Core.@__doc__ $f, $f!
207+
end
208+
return esc(ex)
198209
end
199210

200211
"""

src/interface/eig.jl

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,27 +87,20 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
8787

8888
# Algorithm selection
8989
# -------------------
90-
for f in (:eig_full, :eig_vals)
91-
f! = Symbol(f, :!)
92-
@eval begin
93-
function default_algorithm(::typeof($f), A; kwargs...)
94-
return default_algorithm($f!, A; kwargs...)
95-
end
96-
function default_algorithm(::typeof($f!), A; kwargs...)
97-
return default_eig_algorithm(A; kwargs...)
98-
end
99-
end
90+
default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...)
91+
default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,)))
92+
function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
93+
return LAPACK_Expert(; kwargs...)
10094
end
10195

102-
function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...)
103-
return select_algorithm(eig_trunc!, A, alg; kwargs...)
96+
for f in (:eig_full!, :eig_vals!)
97+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
98+
return default_eig_algorithm(A; kwargs...)
99+
end
104100
end
105-
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
101+
102+
function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
103+
kwargs...) where {A<:YALAPACK.BlasMat}
106104
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
107105
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
108106
end
109-
110-
# Default to LAPACK
111-
function default_eig_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
112-
return LAPACK_Expert(; kwargs...)
113-
end

src/interface/eigh.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,27 +86,22 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc)
8686

8787
# Algorithm selection
8888
# -------------------
89-
for f in (:eigh_full, :eigh_vals)
90-
f! = Symbol(f, :!)
91-
@eval begin
92-
function default_algorithm(::typeof($f), A; kwargs...)
93-
return default_algorithm($f!, A; kwargs...)
94-
end
95-
function default_algorithm(::typeof($f!), A; kwargs...)
96-
return default_eigh_algorithm(A; kwargs...)
97-
end
98-
end
89+
default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs...)
90+
function default_eigh_algorithm(T::Type; kwargs...)
91+
throw(MethodError(default_eigh_algorithm, (T,)))
92+
end
93+
function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
94+
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
9995
end
10096

101-
function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...)
102-
return select_algorithm(eigh_trunc!, A, alg; kwargs...)
97+
for f in (:eigh_full!, :eigh_vals!)
98+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
99+
return default_eigh_algorithm(A; kwargs...)
100+
end
103101
end
104-
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
102+
103+
function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
104+
kwargs...) where {A<:YALAPACK.BlasMat}
105105
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
106106
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
107107
end
108-
109-
# Default to LAPACK
110-
function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
111-
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
112-
end

src/interface/lq.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,18 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
6868

6969
# Algorithm selection
7070
# -------------------
71-
for f in (:lq_full, :lq_compact, :lq_null)
72-
f! = Symbol(f, :!)
71+
default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)
72+
function default_lq_algorithm(T::Type; kwargs...)
73+
throw(MethodError(default_lq_algorithm, (T,)))
74+
end
75+
function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
76+
return LAPACK_HouseholderLQ(; kwargs...)
77+
end
78+
79+
for f in (:lq_full!, :lq_compact!, :lq_null!)
7380
@eval begin
74-
function default_algorithm(::typeof($f), A; kwargs...)
75-
return default_algorithm($f!, A; kwargs...)
76-
end
77-
function default_algorithm(::typeof($f!), A; kwargs...)
81+
function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
7882
return default_lq_algorithm(A; kwargs...)
7983
end
8084
end
8185
end
82-
83-
# Default to LAPACK
84-
function default_lq_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
85-
return LAPACK_HouseholderLQ(; kwargs...)
86-
end

src/interface/polar.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,16 @@ end
6060

6161
# Algorithm selection
6262
# -------------------
63-
for f in (:left_polar, :right_polar)
64-
f! = Symbol(f, :!)
65-
@eval begin
66-
function default_algorithm(::typeof($f), A; kwargs...)
67-
return default_algorithm($f!, A; kwargs...)
68-
end
69-
function default_algorithm(::typeof($f!), A; kwargs...)
70-
return default_polar_algorithm(A; kwargs...)
71-
end
72-
end
63+
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
64+
function default_polar_algorithm(T::Type; kwargs...)
65+
throw(MethodError(default_polar_algorithm, (T,)))
66+
end
67+
function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
68+
return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...))
7369
end
7470

75-
# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`
76-
function default_polar_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
77-
return PolarViaSVD(default_svd_algorithm(A; kwargs...))
71+
for f in (:left_polar!, :right_polar!)
72+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
73+
return default_polar_algorithm(A; kwargs...)
74+
end
7875
end

src/interface/qr.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,19 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
6868

6969
# Algorithm selection
7070
# -------------------
71-
for f in (:qr_full, :qr_compact, :qr_null)
72-
f! = Symbol(f, :!)
71+
default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)
72+
function default_qr_algorithm(T::Type; kwargs...)
73+
throw(MethodError(default_qr_algorithm, (T,)))
74+
end
75+
function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
76+
return LAPACK_HouseholderQR(; kwargs...)
77+
end
78+
79+
for f in (:qr_full!, :qr_compact!, :qr_null!)
7380
@eval begin
74-
function default_algorithm(::typeof($f), A; kwargs...)
75-
return default_algorithm($f!, A; kwargs...)
76-
end
77-
function default_algorithm(::typeof($f!), A; kwargs...)
81+
function default_algorithm(::typeof($f), ::Type{A};
82+
kwargs...) where {A<:YALAPACK.BlasMat}
7883
return default_qr_algorithm(A; kwargs...)
7984
end
8085
end
8186
end
82-
83-
# Default to LAPACK
84-
function default_qr_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...)
85-
return LAPACK_HouseholderQR(; kwargs...)
86-
end

src/interface/schur.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,8 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
5151

5252
# Algorithm selection
5353
# -------------------
54-
for f in (:schur_full, :schur_vals)
55-
f! = Symbol(f, :!)
56-
@eval begin
57-
function default_algorithm(::typeof($f), A; kwargs...)
58-
return default_algorithm($f!, A; kwargs...)
59-
end
60-
function default_algorithm(::typeof($f!), A; kwargs...)
61-
return default_eig_algorithm(A; kwargs...)
62-
end
54+
for f in (:schur_full!, :schur_vals!)
55+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
56+
return default_eig_algorithm(A; kwargs...)
6357
end
6458
end

0 commit comments

Comments
 (0)