Skip to content

Commit df4b2b8

Browse files
committed
Add more tests
1 parent 7dff266 commit df4b2b8

File tree

3 files changed

+73
-4
lines changed

3 files changed

+73
-4
lines changed

src/algorithms.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,20 @@ can't be passed to `MatrixAlgebraKit.select_algorithm`.
7272
"""
7373
function select_algorithm end
7474

75-
Base.@constprop :aggressive function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg}
75+
Base.@constprop :aggressive function select_algorithm(f::F, A, alg::Alg=nothing;
76+
kwargs...) where {F,Alg}
7677
return _select_algorithm(f, A, alg; kwargs...)
7778
end
7879

7980
function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F}
8081
return default_algorithm(f, A; kwargs...)
8182
end
82-
Base.@constprop :aggressive function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F}
83+
Base.@constprop :aggressive function _select_algorithm(f::F, A, alg::Symbol;
84+
kwargs...) where {F}
8385
return Algorithm{alg}(; kwargs...)
8486
end
85-
Base.@constprop :aggressive function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg}
87+
Base.@constprop :aggressive function _select_algorithm(f::F, A, ::Type{Alg};
88+
kwargs...) where {F,Alg}
8689
return Alg(; kwargs...)
8790
end
8891
function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F}
@@ -174,7 +177,8 @@ macro functiondef(f)
174177

175178
return esc(quote
176179
# out of place to inplace
177-
Base.@constprop :aggressive $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
180+
Base.@constprop :aggressive $f(A; kwargs...) = $f!(copy_input($f, A);
181+
kwargs...)
178182
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
179183

180184
# fill in arguments

test/algorithms.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm,
5+
default_algorithm, select_algorithm
6+
7+
@testset "default_algorithm" begin
8+
A = randn(3, 3)
9+
for f in (svd_compact!, svd_compact, svd_full!, svd_full)
10+
@test @constinferred(default_algorithm(f, A)) === LAPACK_DivideAndConquer()
11+
end
12+
for f in (eig_full!, eig_full, eig_vals!, eig_vals)
13+
@test @constinferred(default_algorithm(f, A)) === LAPACK_Expert()
14+
end
15+
for f in (eigh_full!, eigh_full, eigh_vals!, eigh_vals)
16+
@test @constinferred(default_algorithm(f, A)) ===
17+
LAPACK_MultipleRelativelyRobustRepresentations()
18+
end
19+
for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null)
20+
@test @constinferred(default_algorithm(f, A)) == LAPACK_HouseholderLQ()
21+
end
22+
for f in (left_polar!, left_polar, right_polar!, right_polar)
23+
@test @constinferred(default_algorithm(f, A)) ==
24+
PolarViaSVD(LAPACK_DivideAndConquer())
25+
end
26+
for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null)
27+
@test @constinferred(default_algorithm(f, A)) == LAPACK_HouseholderQR()
28+
end
29+
for f in (schur_full!, schur_full, schur_vals!, schur_vals)
30+
@test @constinferred(default_algorithm(f, A)) === LAPACK_Expert()
31+
end
32+
33+
@test @constinferred(default_algorithm(qr_compact!, A; blocksize=2)) ===
34+
LAPACK_HouseholderQR(; blocksize=2)
35+
end
36+
37+
@testset "select_algorithm" begin
38+
A = randn(3, 3)
39+
for f in (svd_trunc!, svd_trunc)
40+
@test @constinferred(select_algorithm(f, A)) ===
41+
TruncatedAlgorithm(LAPACK_DivideAndConquer(), NoTruncation())
42+
end
43+
for f in (eig_trunc!, eig_trunc)
44+
@test @constinferred(select_algorithm(f, A)) ===
45+
TruncatedAlgorithm(LAPACK_Expert(), NoTruncation())
46+
end
47+
for f in (eigh_trunc!, eigh_trunc)
48+
@test @constinferred(select_algorithm(f, A)) ===
49+
TruncatedAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(),
50+
NoTruncation())
51+
end
52+
53+
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
54+
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) ===
55+
LAPACK_DivideAndConquer()
56+
@test @constinferred(select_algorithm(svd_compact!, A, :LAPACK_QRIteration)) ===
57+
LAPACK_QRIteration()
58+
@test @constinferred(select_algorithm(svd_compact!, A, LAPACK_QRIteration)) ===
59+
LAPACK_QRIteration()
60+
@test @constinferred(select_algorithm(svd_compact!, A, LAPACK_QRIteration())) ===
61+
LAPACK_QRIteration()
62+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using SafeTestsets
22

3+
@safetestset "Algorithms" begin
4+
include("algorithms.jl")
5+
end
36
@safetestset "Truncate" begin
47
include("truncate.jl")
58
end

0 commit comments

Comments
 (0)