Skip to content

Commit 38f0c55

Browse files
committed
Use TestSuite for QR and LQ
1 parent 8bded8b commit 38f0c55

File tree

16 files changed

+724
-1363
lines changed

16 files changed

+724
-1363
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ GenericSchur = "0.5.6"
3333
JET = "0.9, 0.10"
3434
LinearAlgebra = "1"
3535
Mooncake = "0.4.183"
36+
Random = "1"
3637
SafeTestsets = "0.1"
3738
StableRNGs = "1"
3839
Test = "1"
39-
TestExtras = "0.2,0.3"
40+
TestExtras = "0.3.2"
4041
Zygote = "0.7"
4142
julia = "1.10"
4243

@@ -47,11 +48,12 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4748
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4849
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4950
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
51+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5052
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5153
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
5254
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5355
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5456
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5557

5658
[targets]
57-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
59+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
5656
return eigvals!(Hermitian(A); sortby = real)
5757
end
5858

59-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
59+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
6060
return GLA_HouseholderQR(; kwargs...)
6161
end
6262

@@ -109,7 +109,7 @@ function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksi
109109
return Q, R
110110
end
111111

112-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
112+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
113113
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
114114
end
115115

src/implementations/lq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ function _lapack_lq!(
153153
computeL = length(L) > 0
154154
inplaceQ = Q === A
155155

156-
if pivoted
157-
throw(ArgumentError("LAPACK does not provide an implementation for a pivoted LQ decomposition"))
156+
if pivoted && (blocksize > 1)
157+
throw(ArgumentError("LAPACK does not provide a blocked implementation for a pivoted LQ decomposition"))
158158
end
159159
if inplaceQ && (computeL || positive || blocksize > 1 || n < m)
160160
throw(ArgumentError("inplace Q only supported if matrix is wide (`m <= n`), L is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`"))

src/implementations/qr.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,12 @@ function _gpu_unmqr!(
270270
end
271271

272272
function _gpu_qr!(
273-
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive = false, blocksize = 1
273+
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive = false, blocksize = 1, pivoted = false
274274
)
275275
blocksize > 1 &&
276276
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
277+
pivoted &&
278+
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition"))
277279
m, n = size(A)
278280
minmn = min(m, n)
279281
computeR = length(R) > 0
@@ -309,10 +311,12 @@ function _gpu_qr!(
309311
end
310312

311313
function _gpu_qr_null!(
312-
A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1
314+
A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1, pivoted = false
313315
)
314316
blocksize > 1 &&
315317
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
318+
pivoted &&
319+
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition"))
316320
m, n = size(A)
317321
minmn = min(m, n)
318322
fill!(N, zero(eltype(N)))

test/amd/lq.jl

Lines changed: 0 additions & 163 deletions
This file was deleted.

0 commit comments

Comments
 (0)