Skip to content

Commit 9f014ff

Browse files
committed
Support fp16 in GLA for LQ/QR and test GLA LQ/QR
1 parent 57903cb commit 9f014ff

File tree

6 files changed

+54
-42
lines changed

6 files changed

+54
-42
lines changed

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
5656
return eigvals!(Hermitian(A); sortby = real)
5757
end
5858

59-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
59+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
6060
return GLA_HouseholderQR(; kwargs...)
6161
end
6262

@@ -109,7 +109,7 @@ function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksi
109109
return Q, R
110110
end
111111

112-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
112+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
113113
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
114114
end
115115

test/lq.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using StableRNGs
44
using LinearAlgebra: diag, I, Diagonal
55
using MatrixAlgebraKit: LQViaTransposedQR, LAPACK_HouseholderLQ
6-
using CUDA, AMDGPU
6+
using CUDA, AMDGPU, GenericLinearAlgebra
77

88
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
@@ -14,9 +14,9 @@ using .TestSuite
1414
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1515

1616
m = 54
17-
for T in BLASFloats, n in (37, m, 63)
17+
for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
1818
TestSuite.seed_rng!(123)
19-
if is_buildkite
19+
if is_buildkite && T BLASFloats
2020
if CUDA.functional()
2121
CUDA_LQ_ALGS = LQViaTransposedQR.(CUSOLVER_HouseholderLQ(; positive = false), CUSOLVER_HouseholderLQ(; positive = true))
2222
TestSuite.test_lq(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
@@ -36,18 +36,24 @@ for T in BLASFloats, n in (37, m, 63)
3636
end
3737
end
3838
else
39-
TestSuite.test_lq(T, (m, n))
40-
LAPACK_LQ_ALGS = (
41-
LAPACK_HouseholderLQ(; positive = false, pivoted = false, blocksize = 1),
42-
LAPACK_HouseholderLQ(; positive = false, pivoted = false, blocksize = 2),
43-
LAPACK_HouseholderLQ(; positive = false, pivoted = true, blocksize = 1),
44-
#LAPACK_HouseholderLQ(; positive=false, pivoted=true, blocksize=2), # not supported
45-
LAPACK_HouseholderLQ(; positive = true, pivoted = false, blocksize = 1),
46-
LAPACK_HouseholderLQ(; positive = true, pivoted = false, blocksize = 2),
47-
LAPACK_HouseholderLQ(; positive = true, pivoted = true, blocksize = 1),
48-
#LAPACK_HouseholderLQ(; positive=true, pivoted=true, blocksize=2), # not supported
49-
)
50-
TestSuite.test_lq_algs(T, (m, n), LAPACK_LQ_ALGS)
39+
if T BLASFloats
40+
TestSuite.test_lq(T, (m, n))
41+
LAPACK_LQ_ALGS = (
42+
LAPACK_HouseholderLQ(; positive = false, pivoted = false, blocksize = 1),
43+
LAPACK_HouseholderLQ(; positive = false, pivoted = false, blocksize = 2),
44+
LAPACK_HouseholderLQ(; positive = false, pivoted = true, blocksize = 1),
45+
#LAPACK_HouseholderLQ(; positive=false, pivoted=true, blocksize=2), # not supported
46+
LAPACK_HouseholderLQ(; positive = true, pivoted = false, blocksize = 1),
47+
LAPACK_HouseholderLQ(; positive = true, pivoted = false, blocksize = 2),
48+
LAPACK_HouseholderLQ(; positive = true, pivoted = true, blocksize = 1),
49+
#LAPACK_HouseholderLQ(; positive=true, pivoted=true, blocksize=2), # not supported
50+
)
51+
TestSuite.test_lq_algs(T, (m, n), LAPACK_LQ_ALGS)
52+
elseif T GenericFloats
53+
TestSuite.test_lq(T, (m, n); test_null = false, test_pivoted = false, test_blocksize = false)
54+
GLA_LQ_ALGS = (LQViaTransposedQR(GLA_HouseholderQR()),)
55+
TestSuite.test_lq_algs(T, (m, n), GLA_LQ_ALGS; test_null = false)
56+
end
5157
end
5258
end
5359
if !is_buildkite

test/qr.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using StableRNGs
44
using LinearAlgebra: diag, I, Diagonal
5-
using CUDA, AMDGPU
5+
using CUDA, AMDGPU, GenericLinearAlgebra
66

77
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
88
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
@@ -13,9 +13,9 @@ using .TestSuite
1313
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1414

1515
m = 54
16-
for T in BLASFloats, n in (37, m, 63)
16+
for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
1717
TestSuite.seed_rng!(123)
18-
if is_buildkite
18+
if is_buildkite && T BLASFloats
1919
if CUDA.functional()
2020
CUDA_QR_ALGS = (CUSOLVER_HouseholderQR(; positive = false), CUSOLVER_HouseholderQR(; positive = true))
2121
TestSuite.test_qr(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
@@ -35,18 +35,24 @@ for T in BLASFloats, n in (37, m, 63)
3535
end
3636
end
3737
else
38-
TestSuite.test_qr(T, (m, n))
39-
LAPACK_QR_ALGS = (
40-
LAPACK_HouseholderQR(; positive = false, pivoted = false, blocksize = 1),
41-
LAPACK_HouseholderQR(; positive = false, pivoted = false, blocksize = 2),
42-
LAPACK_HouseholderQR(; positive = false, pivoted = true, blocksize = 1),
43-
#LAPACK_HouseholderQR(; positive=false, pivoted=true, blocksize=2), # not supported
44-
LAPACK_HouseholderQR(; positive = true, pivoted = false, blocksize = 1),
45-
LAPACK_HouseholderQR(; positive = true, pivoted = false, blocksize = 2),
46-
LAPACK_HouseholderQR(; positive = true, pivoted = true, blocksize = 1),
47-
#LAPACK_HouseholderQR(; positive=true, pivoted=true, blocksize=2), # not supported
48-
)
49-
TestSuite.test_qr_algs(T, (m, n), LAPACK_QR_ALGS)
38+
if T BLASFloats
39+
TestSuite.test_qr(T, (m, n))
40+
LAPACK_QR_ALGS = (
41+
LAPACK_HouseholderQR(; positive = false, pivoted = false, blocksize = 1),
42+
LAPACK_HouseholderQR(; positive = false, pivoted = false, blocksize = 2),
43+
LAPACK_HouseholderQR(; positive = false, pivoted = true, blocksize = 1),
44+
#LAPACK_HouseholderQR(; positive=false, pivoted=true, blocksize=2), # not supported
45+
LAPACK_HouseholderQR(; positive = true, pivoted = false, blocksize = 1),
46+
LAPACK_HouseholderQR(; positive = true, pivoted = false, blocksize = 2),
47+
LAPACK_HouseholderQR(; positive = true, pivoted = true, blocksize = 1),
48+
#LAPACK_HouseholderQR(; positive=true, pivoted=true, blocksize=2), # not supported
49+
)
50+
TestSuite.test_qr_algs(T, (m, n), LAPACK_QR_ALGS)
51+
elseif T GenericFloats
52+
TestSuite.test_qr(T, (m, n); test_null = false, test_pivoted = false, test_blocksize = false)
53+
GLA_QR_ALGS = (GLA_HouseholderQR(),)
54+
TestSuite.test_qr_algs(T, (m, n), GLA_QR_ALGS; test_null = false)
55+
end
5056
end
5157
end
5258
if !is_buildkite

test/testsuite/TestSuite.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ is_positive(alg::MatrixAlgebraKit.LAPACK_HouseholderLQ) = alg.positive
6666
is_pivoted(alg::MatrixAlgebraKit.LAPACK_HouseholderLQ) = alg.pivoted
6767
is_positive(alg::MatrixAlgebraKit.CUSOLVER_HouseholderQR) = alg.positive
6868
is_positive(alg::MatrixAlgebraKit.ROCSOLVER_HouseholderQR) = alg.positive
69-
is_positive(alg::MatrixAlgebraKit.LQViaTransposedQR) = alg.alg.positive
70-
is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = alg.alg.pivoted
69+
is_positive(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_positive(alg.qr_alg)
70+
is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg)
7171

7272
include("qr.jl")
7373
include("lq.jl")

test/testsuite/lq.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
using TestExtras
22

3-
function test_lq(T::Type, sz; kwargs...)
3+
function test_lq(T::Type, sz; test_null = true, kwargs...)
44
summary_str = testargs_summary(T, sz)
55
return @testset "lq $summary_str" begin
66
test_lq_compact(T, sz; kwargs...)
77
test_lq_full(T, sz; kwargs...)
8-
test_lq_null(T, sz; kwargs...)
8+
test_null && test_lq_null(T, sz; kwargs...)
99
end
1010
end
1111

12-
function test_lq_algs(T::Type, sz, algs; kwargs...)
12+
function test_lq_algs(T::Type, sz, algs; test_null = true, kwargs...)
1313
summary_str = testargs_summary(T, sz)
1414
return @testset "lq algorithms $summary_str" begin
1515
test_lq_compact_algs(T, sz, algs; kwargs...)
1616
test_lq_full_algs(T, sz, algs; kwargs...)
17-
test_lq_null_algs(T, sz, algs; kwargs...)
17+
test_null && test_lq_null_algs(T, sz, algs; kwargs...)
1818
end
1919
end
2020

test/testsuite/qr.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
using TestExtras
22

3-
function test_qr(T::Type, sz; kwargs...)
3+
function test_qr(T::Type, sz; test_null = true, kwargs...)
44
summary_str = testargs_summary(T, sz)
55
return @testset "qr $summary_str" begin
66
test_qr_compact(T, sz; kwargs...)
77
test_qr_full(T, sz; kwargs...)
8-
test_qr_null(T, sz; kwargs...)
8+
test_null && test_qr_null(T, sz; kwargs...)
99
end
1010
end
1111

12-
function test_qr_algs(T::Type, sz, algs; kwargs...)
12+
function test_qr_algs(T::Type, sz, algs; test_null = true, kwargs...)
1313
summary_str = testargs_summary(T, sz)
1414
return @testset "qr algorithms $summary_str" begin
1515
test_qr_compact_algs(T, sz, algs; kwargs...)
1616
test_qr_full_algs(T, sz, algs; kwargs...)
17-
test_qr_null_algs(T, sz, algs; kwargs...)
17+
test_null && test_qr_null_algs(T, sz, algs; kwargs...)
1818
end
1919
end
2020

0 commit comments

Comments
 (0)