Skip to content

Commit d4099c5

Browse files
authored
Make select_algorithm more agnostic about being in the object or type domain (#32)
1 parent 2acaccd commit d4099c5

File tree

11 files changed

+103
-34
lines changed

11 files changed

+103
-34
lines changed

src/algorithms.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ passed as the third positional argument in the form of a `NamedTuple`.
7777
""" select_algorithm
7878

7979
function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
80-
return select_algorithm(f, typeof(A), alg; kwargs...)
81-
end
82-
function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg}
8380
if isnothing(alg)
8481
return default_algorithm(f, A; kwargs...)
8582
elseif alg isa Symbol
@@ -193,10 +190,24 @@ macro functiondef(f)
193190
end
194191

195192
# define fallbacks for algorithm selection
196-
@inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg;
197-
kwargs...) where {Alg,A}
193+
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
198194
return select_algorithm($f!, A, alg; kwargs...)
199195
end
196+
# define default algorithm fallbacks for out-of-place functions
197+
# in terms of the corresponding in-place function
198+
@inline function default_algorithm(::typeof($f), A; kwargs...)
199+
return default_algorithm($f!, A; kwargs...)
200+
end
201+
# define default algorithm fallbacks for out-of-place functions
202+
# in terms of the corresponding in-place function for types,
203+
# in principle this is covered by the definition above but
204+
# it is necessary to avoid ambiguity errors with the generic definitions:
205+
# ```julia
206+
# default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
207+
# function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
208+
# throw(MethodError(default_algorithm, (f, T)))
209+
# end
210+
# ```
200211
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
201212
return default_algorithm($f!, A; kwargs...)
202213
end

src/interface/eig.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ for f in (:eig_full!, :eig_vals!)
9999
end
100100
end
101101

102-
function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing,
103-
kwargs...) where {A<:YALAPACK.BlasMat}
104-
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
105-
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
102+
function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...)
103+
if alg isa TruncatedAlgorithm
104+
isnothing(trunc) ||
105+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
106+
return alg
107+
else
108+
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
109+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
110+
end
106111
end

src/interface/eigh.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,13 @@ for f in (:eigh_full!, :eigh_vals!)
100100
end
101101
end
102102

103-
function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing,
104-
kwargs...) where {A<:YALAPACK.BlasMat}
105-
alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...)
106-
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
103+
function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...)
104+
if alg isa TruncatedAlgorithm
105+
isnothing(trunc) ||
106+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
107+
return alg
108+
else
109+
alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...)
110+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
111+
end
107112
end

src/interface/lq.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
7777
end
7878

7979
for f in (:lq_full!, :lq_compact!, :lq_null!)
80-
@eval begin
81-
function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
82-
return default_lq_algorithm(A; kwargs...)
83-
end
80+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
81+
return default_lq_algorithm(A; kwargs...)
8482
end
8583
end

src/interface/polar.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,7 @@ end
6161
# Algorithm selection
6262
# -------------------
6363
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
64-
function default_polar_algorithm(T::Type; kwargs...)
65-
throw(MethodError(default_polar_algorithm, (T,)))
66-
end
67-
function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
64+
function default_polar_algorithm(::Type{T}; kwargs...) where {T}
6865
return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...))
6966
end
7067

src/interface/qr.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat}
7777
end
7878

7979
for f in (:qr_full!, :qr_compact!, :qr_null!)
80-
@eval begin
81-
function default_algorithm(::typeof($f), ::Type{A};
82-
kwargs...) where {A<:YALAPACK.BlasMat}
83-
return default_qr_algorithm(A; kwargs...)
84-
end
80+
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
81+
return default_qr_algorithm(A; kwargs...)
8582
end
8683
end

src/interface/svd.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,13 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!)
104104
end
105105
end
106106

107-
function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing,
108-
kwargs...) where {A<:YALAPACK.BlasMat}
109-
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
110-
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
107+
function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...)
108+
if alg isa TruncatedAlgorithm
109+
isnothing(trunc) ||
110+
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
111+
return alg
112+
else
113+
alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...)
114+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
115+
end
111116
end

test/algorithms.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm,
5-
default_algorithm, select_algorithm
5+
TruncationKeepBelow, default_algorithm, select_algorithm
66

77
@testset "default_algorithm" begin
88
A = randn(3, 3)
@@ -50,6 +50,12 @@ end
5050
NoTruncation())
5151
end
5252

53+
alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0))
54+
for f in (eig_trunc!, eigh_trunc!, svd_trunc!)
55+
@test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg
56+
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2))
57+
end
58+
5359
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
5460
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) ===
5561
LAPACK_DivideAndConquer()

test/eig.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: Diagonal
6-
using MatrixAlgebraKit: diagview
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77

88
@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
99
rng = StableRNG(123)
@@ -57,3 +57,17 @@ end
5757
@test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
5858
end
5959
end
60+
61+
@testset "eig_trunc! specify truncation algorithm T = $T" for T in
62+
(Float32, Float64, ComplexF32,
63+
ComplexF64)
64+
rng = StableRNG(123)
65+
m = 4
66+
V = randn(rng, T, m, m)
67+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
68+
A = V * D * inv(V)
69+
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
70+
D2, V2 = @constinferred eig_trunc(A; alg)
71+
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
72+
@test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2))
73+
end

test/eigh.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, I
6-
using MatrixAlgebraKit: diagview
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview
77

88
@testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
99
rng = StableRNG(123)
@@ -62,3 +62,19 @@ end
6262
@test V2 * (V2' * V1) V1
6363
end
6464
end
65+
66+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in
67+
(Float32, Float64,
68+
ComplexF32,
69+
ComplexF64)
70+
rng = StableRNG(123)
71+
m = 4
72+
V = qr_compact(randn(rng, T, m, m))[1]
73+
D = Diagonal([0.9, 0.3, 0.1, 0.01])
74+
A = V * D * V'
75+
A = (A + A') / 2
76+
alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2))
77+
D2, V2 = @constinferred eigh_trunc(A; alg)
78+
@test diagview(D2) diagview(D)[1:2] rtol = sqrt(eps(real(T)))
79+
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
80+
end

0 commit comments

Comments
 (0)