Skip to content

Commit 430e91b

Browse files
committed
Docs, tests, style improvements
1 parent cabd897 commit 430e91b

File tree

8 files changed

+43
-14
lines changed

8 files changed

+43
-14
lines changed

docs/src/dev_interface.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
```@meta
2+
CurrentModule = MatrixAlgebraKit
3+
CollapsedDocStrings = true
4+
```
5+
6+
# Developer Interface
7+
8+
MatrixAlgebraKit.jl provides a developer interface for specifying custom algorithm backends and selecting default algorithms.
9+
10+
```@docs; canonical=false
11+
MatrixAlgebraKit.default_algorithm
12+
MatrixAlgebraKit.select_algorithm
13+
```

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3030
LAPACK_DivideAndConquer, LAPACK_Jacobi
3131
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
3232

33+
public default_truncation, select_algorithm
34+
3335
include("common/defaults.jl")
3436
include("common/initialization.jl")
3537
include("common/pullbacks.jl")

src/algorithms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function _show_alg(io::IO, alg::Algorithm)
5454
end
5555

5656
@doc """
57-
select_algorithm(f, A; kwargs...)
57+
MatrixAlgebraKit.select_algorithm(f, A; kwargs...)
5858
5959
Given some keyword arguments and an input `A`, decide on an algorithm to use for
6060
implementing the function `f` on inputs of type `A`.
@@ -82,7 +82,7 @@ function _select_algorithm(f, A, alg::AbstractAlgorithm; kwargs...)
8282
return alg
8383
end
8484
function _select_algorithm(f, A, alg::Symbol; kwargs...)
85-
return _select_algorithm(f, A, Algorithm{alg}; kwargs...)
85+
return _select_algorithm(f, A, Algorithm{alg}(; kwargs...))
8686
end
8787
function _select_algorithm(f, A, alg::Type; kwargs...)
8888
return _select_algorithm(f, A, alg(; kwargs...))
@@ -97,7 +97,7 @@ function _select_algorithm(f, A, alg; kwargs...)
9797
end
9898

9999
@doc """
100-
default_algorithm(f, A; kwargs...)
100+
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
101101
102102
Select the default algorithm for a given factorization function `f` and input `A`.
103103
In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified

src/implementations/truncation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ Trivial truncation strategy that keeps all values, mostly for testing purposes.
3232
"""
3333
struct NoTruncation <: TruncationStrategy end
3434

35-
function to_truncationstrategy(trunc::TruncationStrategy)
35+
function select_truncation(trunc::TruncationStrategy)
3636
return trunc
3737
end
38-
function to_truncationstrategy(trunc::NamedTuple)
38+
function select_truncation(trunc::NamedTuple)
3939
return TruncationStrategy(; trunc...)
4040
end
41-
function to_truncationstrategy(trunc::Nothing)
41+
function select_truncation(trunc::Nothing)
4242
return NoTruncation()
4343
end
44-
function to_truncationstrategy(trunc)
44+
function select_truncation(trunc)
4545
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
4646
end
4747

src/interface/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function select_algorithm(::typeof(eig_trunc), A; kwargs...)
104104
end
105105
function select_algorithm(::typeof(eig_trunc!), A; trunc=nothing, kwargs...)
106106
alg_eig = select_algorithm(eig_full!, A; kwargs...)
107-
return TruncatedAlgorithm(alg_eig, to_truncationstrategy(trunc))
107+
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
108108
end
109109

110110
# Default to LAPACK

src/interface/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function select_algorithm(::typeof(eigh_trunc), A; kwargs...)
103103
end
104104
function select_algorithm(::typeof(eigh_trunc!), A; trunc=nothing, kwargs...)
105105
alg_eigh = select_algorithm(eigh_full!, A; kwargs...)
106-
return TruncatedAlgorithm(alg_eigh, to_truncationstrategy(trunc))
106+
return TruncatedAlgorithm(alg_eigh, select_truncation(trunc))
107107
end
108108

109109
# Default to LAPACK

src/interface/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function select_algorithm(::typeof(svd_trunc), A; kwargs...)
107107
end
108108
function select_algorithm(::typeof(svd_trunc!), A; trunc=nothing, kwargs...)
109109
alg_svd = select_algorithm(svd_compact!, A; kwargs...)
110-
return TruncatedAlgorithm(alg_svd, to_truncationstrategy(trunc))
110+
return TruncatedAlgorithm(alg_svd, select_truncation(trunc))
111111
end
112112

113113
# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}`

test/eig.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,35 @@ using MatrixAlgebraKit: diagview
88
@testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
99
rng = StableRNG(123)
1010
m = 54
11-
for alg in (LAPACK_Simple(), LAPACK_Expert())
11+
for alg in
12+
(LAPACK_Simple(), LAPACK_Expert(), LAPACK_Simple, LAPACK_Expert, :LAPACK_Simple,
13+
:LAPACK_Expert)
1214
A = randn(rng, T, m, m)
1315
Tc = complex(T)
1416

15-
D, V = @constinferred eig_full(A; alg)
17+
alg′ = if (alg isa Type) || (alg isa Symbol)
18+
# These cases aren't inferable right now.
19+
MatrixAlgebraKit.select_algorithm(eig_full!, A; alg)
20+
else
21+
@constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A; alg)
22+
end
23+
24+
D, V = if (alg isa Type) || (alg isa Symbol)
25+
# These cases aren't inferable right now.
26+
eig_full(A; alg)
27+
else
28+
@constinferred eig_full(A; alg)
29+
end
1630
@test eltype(D) == eltype(V) == Tc
1731
@test A * V V * D
1832

1933
Ac = similar(A)
20-
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg)
34+
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg)
2135
@test D2 === D
2236
@test V2 === V
2337
@test A * V V * D
2438

25-
Dc = @constinferred eig_vals(A, alg)
39+
Dc = @constinferred eig_vals(A, alg)
2640
@test eltype(Dc) == Tc
2741
@test D Diagonal(Dc)
2842
end

0 commit comments

Comments
 (0)