Skip to content

Commit 3848f2b

Browse files
authored
Use Testsuite for Schur (#127)
* Use Testsuite for Schur * Schur full and vals for more types
1 parent 4212b74 commit 3848f2b

File tree

6 files changed

+103
-146
lines changed

6 files changed

+103
-146
lines changed

ext/MatrixAlgebraKitGenericSchurExt.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ module MatrixAlgebraKitGenericSchurExt
22

33
using MatrixAlgebraKit
44
using MatrixAlgebraKit: check_input
5-
using LinearAlgebra: Diagonal
5+
using LinearAlgebra: Diagonal, sorteig!
66
using GenericSchur
77

8-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
8+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
99
return GS_QRIteration(; kwargs...)
1010
end
1111

@@ -21,4 +21,21 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
2121
return GenericSchur.eigvals!(A)
2222
end
2323

24+
function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration)
25+
check_input(schur_full!, A, TZv, alg)
26+
T, Z, vals = TZv
27+
S = GenericSchur.gschur(A)
28+
copyto!(T, S.T)
29+
copyto!(Z, S.Z)
30+
copyto!(vals, S.values)
31+
return T, Z, vals
32+
end
33+
34+
function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration)
35+
check_input(schur_vals!, A, vals, alg)
36+
S = GenericSchur.gschur(A)
37+
copyto!(vals, sorteig!(S.values))
38+
return vals
39+
end
40+
2441
end

test/genericschur/eig.jl

Lines changed: 0 additions & 116 deletions
This file was deleted.

test/runtests.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ if !is_buildkite
2222
@safetestset "Generalized Eigenvalue Decomposition" begin
2323
include("gen_eig.jl")
2424
end
25-
@safetestset "Schur Decomposition" begin
26-
include("schur.jl")
27-
end
2825
@safetestset "Image and Null Space" begin
2926
include("orthnull.jl")
3027
end
@@ -55,10 +52,6 @@ if !is_buildkite
5552
include("genericlinearalgebra/eigh.jl")
5653
end
5754

58-
using GenericSchur
59-
@safetestset "General Eigenvalue Decomposition" begin
60-
include("genericschur/eig.jl")
61-
end
6255
end
6356

6457
@safetestset "QR / LQ Decomposition" begin
@@ -71,6 +64,9 @@ end
7164
@safetestset "Projections" begin
7265
include("projections.jl")
7366
end
67+
@safetestset "Schur Decomposition" begin
68+
include("schur.jl")
69+
end
7470

7571
using CUDA
7672
if CUDA.functional()

test/schur.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,32 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: I
5+
using LinearAlgebra: I, Diagonal
66

7-
@testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
8-
rng = StableRNG(123)
9-
m = 54
10-
for alg in (LAPACK_Simple(), LAPACK_Expert())
11-
A = randn(rng, T, m, m)
12-
Tc = complex(T)
7+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8+
GenericFloats = (BigFloat, Complex{BigFloat})
139

14-
TA, Z, vals = @constinferred schur_full(A; alg)
15-
@test eltype(TA) == eltype(Z) == T
16-
@test eltype(vals) == Tc
17-
@test isisometric(Z)
18-
@test A * Z Z * TA
10+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
11+
using .TestSuite
1912

20-
Ac = similar(A)
21-
TA2, Z2, vals2 = @constinferred schur_full!(copy!(Ac, A), (TA, Z, vals), alg)
22-
@test TA2 === TA
23-
@test Z2 === Z
24-
@test vals2 === vals
25-
@test A * Z Z * TA
13+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
2614

27-
valsc = @constinferred schur_vals(A, alg)
28-
@test eltype(valsc) == Tc
29-
@test valsc eig_vals(A, alg)
15+
m = 54
16+
for T in (BLASFloats..., GenericFloats...)
17+
TestSuite.seed_rng!(123)
18+
if T BLASFloats
19+
#=if CUDA.functional()
20+
TestSuite.test_schur(CuMatrix{T}, (m, m); test_blocksize = false)
21+
TestSuite.test_schur(Diagonal{T, CuVector{T}}, m; test_blocksize = false)
22+
end
23+
if AMDGPU.functional()
24+
TestSuite.test_schur(ROCMatrix{T}, (m, m); test_blocksize = false)
25+
TestSuite.test_schur(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
26+
end=# # not yet supported
27+
end
28+
if !is_buildkite
29+
TestSuite.test_schur(T, (m, m))
30+
#AT = Diagonal{T, Vector{T}}
31+
#TestSuite.test_schur(AT, m) # not supported yet
3032
end
3133
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,6 @@ include("qr.jl")
7373
include("lq.jl")
7474
include("polar.jl")
7575
include("projections.jl")
76+
include("schur.jl")
7677

7778
end

test/testsuite/schur.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using TestExtras
2+
using GenericSchur
3+
4+
function test_schur(T::Type, sz; kwargs...)
5+
summary_str = testargs_summary(T, sz)
6+
return @testset "schur $summary_str" begin
7+
test_schur_full(T, sz; kwargs...)
8+
test_schur_vals(T, sz; kwargs...)
9+
end
10+
end
11+
12+
function test_schur_full(
13+
T::Type, sz;
14+
atol::Real = 0, rtol::Real = precision(T),
15+
kwargs...
16+
)
17+
summary_str = testargs_summary(T, sz)
18+
return @testset "schur_full! $summary_str" begin
19+
A = instantiate_matrix(T, sz)
20+
Ac = deepcopy(A)
21+
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
22+
23+
TA, Z, vals = @testinferred schur_full(A)
24+
@test eltype(TA) == eltype(Z) == eltype(T)
25+
@test eltype(vals) == Tc
26+
@test isisometric(Z)
27+
@test A * Z Z * TA
28+
29+
TA2, Z2, vals2 = @testinferred schur_full!(Ac, (TA, Z, vals))
30+
@test TA2 === TA
31+
@test Z2 === Z
32+
@test vals2 === vals
33+
@test A * Z Z * TA
34+
end
35+
end
36+
37+
function test_schur_vals(
38+
T::Type, sz;
39+
atol::Real = 0, rtol::Real = precision(T),
40+
kwargs...
41+
)
42+
summary_str = testargs_summary(T, sz)
43+
return @testset "schur_vals! $summary_str" begin
44+
A = instantiate_matrix(T, sz)
45+
Ac = deepcopy(A)
46+
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
47+
48+
valsc = @testinferred schur_vals(A)
49+
@test eltype(valsc) == Tc
50+
@test valsc eig_vals(A)
51+
52+
valsc = similar(A, Tc, size(A, 1))
53+
valsc = @testinferred schur_vals!(Ac, valsc)
54+
@test eltype(valsc) == Tc
55+
@test valsc eig_vals(A)
56+
end
57+
end

0 commit comments

Comments
 (0)