Skip to content

Commit 13b5b0c

Browse files
committed
add LQViaTransposedQR, CUDA LQ and tests
1 parent 508d284 commit 13b5b0c

File tree

17 files changed

+304
-70
lines changed

17 files changed

+304
-70
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ MatrixAlgebraKitCUDAExt = "CUDA"
1818
Aqua = "0.6, 0.7, 0.8"
1919
ChainRulesCore = "1"
2020
ChainRulesTestUtils = "1"
21+
CUDA = "5"
2122
JET = "0.9"
2223
LinearAlgebra = "1"
2324
SafeTestsets = "0.1"

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7+
using MatrixAlgebraKit: LQViaTransposedQR
78
using CUDA
89
using LinearAlgebra
910
using LinearAlgebra: BlasFloat
1011

1112
include("yacusolver.jl")
1213
include("implementations/qr.jl")
14+
include("implementations/lq.jl")
1315

1416
end
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
function MatrixAlgebraKit.default_lq_algorithm(A::CuMatrix{<:BlasFloat}; kwargs...)
2+
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
3+
return LQViaTransposedQR(qr_alg)
4+
end

ext/MatrixAlgebraKitCUDAExt/implementations/qr.jl

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,6 @@ function MatrixAlgebraKit.default_qr_algorithm(A::CuMatrix{<:BlasFloat}; kwargs.
1111
return CUSOLVER_HouseholderQR(; kwargs...)
1212
end
1313

14-
# Outputs
15-
# -------
16-
function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), A::AbstractMatrix,
17-
::CUSOLVER_HouseholderQR)
18-
m, n = size(A)
19-
Q = similar(A, (m, m))
20-
R = similar(A, (m, n))
21-
return (Q, R)
22-
end
23-
function MatrixAlgebraKit.initialize_output(::typeof(qr_compact!), A::AbstractMatrix,
24-
::CUSOLVER_HouseholderQR)
25-
m, n = size(A)
26-
minmn = min(m, n)
27-
Q = similar(A, (m, minmn))
28-
R = similar(A, (minmn, n))
29-
return (Q, R)
30-
end
31-
function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), A::AbstractMatrix,
32-
::CUSOLVER_HouseholderQR)
33-
m, n = size(A)
34-
minmn = min(m, n)
35-
N = similar(A, (m, m - minmn))
36-
return N
37-
end
38-
3914
# Implementation
4015
# --------------
4116
# actual implementation
@@ -80,23 +55,23 @@ function _cusolver_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
8055
if positive # already fix Q even if we do not need R
8156
# TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA
8257
τ .= sign_safe.(diagview(A))
83-
Q .= Q .* transpose(τ)
58+
Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q
59+
Qf .= Qf .* transpose(τ)
8460
end
8561

8662
if computeR
8763
= uppertriangular!(view(A, axes(R)...))
8864
if positive
89-
R̃ .= conj.(τ) .*
65+
R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R
66+
R̃f .= conj.(τ) .* R̃f
9067
end
9168
copyto!(R, R̃)
9269
end
9370
return Q, R
9471
end
9572

9673
function _cusolver_qr_null!(A::AbstractMatrix, N::AbstractMatrix;
97-
positive=false,
98-
pivoted=false,
99-
blocksize=1)
74+
positive=false, blocksize=1)
10075
blocksize > 1 &&
10176
throw(ArgumentError("CUSOLVER does not provide a blocked implementation for a QR decomposition"))
10277
m, n = size(A)

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ module MatrixAlgebraKit
22

33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
5-
using LinearAlgebra: mul!, rmul!, lmul!
5+
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
66
using LinearAlgebra: sylvester
77
using LinearAlgebra: isposdef, ishermitian
88
using LinearAlgebra: Diagonal, diag, diagind
99
using LinearAlgebra: UpperTriangular, LowerTriangular
10-
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!, tril!, rdiv!, ldiv!
10+
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt
1111

1212
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
1313
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!

src/algorithms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ macro algdef(name)
9898
return $name{typeof(kw)}(kw)
9999
end
100100
function Base.show(io::IO, alg::$name)
101-
return _show_alg(io, alg)
101+
return ($_show_alg)(io, alg)
102102
end
103103

104104
Core.@__doc__ $name

src/implementations/eig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ end
3030

3131
# Outputs
3232
# -------
33-
function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::LAPACK_EigAlgorithm)
33+
function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::AbstractAlgorithm)
3434
n = size(A, 1) # square check will happen later
3535
Tc = complex(eltype(A))
3636
D = Diagonal(similar(A, Tc, n))
3737
V = similar(A, Tc, (n, n))
3838
return (D, V)
3939
end
40-
function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::LAPACK_EigAlgorithm)
40+
function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlgorithm)
4141
n = size(A, 1) # square check will happen later
4242
Tc = complex(eltype(A))
4343
D = similar(A, Tc, n)

src/implementations/eigh.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ end
2929

3030
# Outputs
3131
# -------
32-
function initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::LAPACK_EighAlgorithm)
32+
function initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::AbstractAlgorithm)
3333
n = size(A, 1) # square check will happen later
3434
D = Diagonal(similar(A, real(eltype(A)), n))
3535
V = similar(A, (n, n))
3636
return (D, V)
3737
end
38-
function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::LAPACK_EighAlgorithm)
38+
function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAlgorithm)
3939
n = size(A, 1) # square check will happen later
4040
D = similar(A, real(eltype(A)), n)
4141
return D

src/implementations/lq.jl

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,20 @@ end
4242

4343
# Outputs
4444
# -------
45-
function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::LAPACK_HouseholderLQ)
45+
function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::AbstractAlgorithm)
4646
m, n = size(A)
4747
L = similar(A, (m, n))
4848
Q = similar(A, (n, n))
4949
return (L, Q)
5050
end
51-
function initialize_output(::typeof(lq_compact!), A::AbstractMatrix, ::LAPACK_HouseholderLQ)
51+
function initialize_output(::typeof(lq_compact!), A::AbstractMatrix, ::AbstractAlgorithm)
5252
m, n = size(A)
5353
minmn = min(m, n)
5454
L = similar(A, (m, minmn))
5555
Q = similar(A, (minmn, n))
5656
return (L, Q)
5757
end
58-
function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::LAPACK_HouseholderLQ)
58+
function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::AbstractAlgorithm)
5959
m, n = size(A)
6060
minmn = min(m, n)
6161
Nᴴ = similar(A, (n - minmn, n))
@@ -71,17 +71,34 @@ function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
7171
_lapack_lq!(A, L, Q; alg.kwargs...)
7272
return L, Q
7373
end
74+
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
75+
check_input(lq_full!, A, LQ)
76+
L, Q = LQ
77+
lq_via_qr!(A, L, Q, alg.qr_alg)
78+
return L, Q
79+
end
7480
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
7581
check_input(lq_compact!, A, LQ)
7682
L, Q = LQ
7783
_lapack_lq!(A, L, Q; alg.kwargs...)
7884
return L, Q
7985
end
86+
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
87+
check_input(lq_compact!, A, LQ)
88+
L, Q = LQ
89+
lq_via_qr!(A, L, Q, alg.qr_alg)
90+
return L, Q
91+
end
8092
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
8193
check_input(lq_null!, A, Nᴴ)
8294
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
8395
return Nᴴ
8496
end
97+
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
98+
check_input(lq_null!, A, Nᴴ)
99+
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
100+
return Nᴴ
101+
end
85102

86103
function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
87104
positive=false,
@@ -158,3 +175,31 @@ function _lapack_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix;
158175
end
159176
return Nᴴ
160177
end
178+
179+
# LQ via transposition and QR
180+
function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
181+
qr_alg::AbstractAlgorithm)
182+
m, n = size(A)
183+
minmn = min(m, n)
184+
At = adjoint!(similar(A'), A)
185+
Qt = (A === Q) ? At : similar(Q')
186+
Lt = similar(L')
187+
if size(Q) == (n, n)
188+
Qt, Lt = qr_full!(At, (Qt, Lt), qr_alg)
189+
else
190+
Qt, Lt = qr_compact!(At, (Qt, Lt), qr_alg)
191+
end
192+
adjoint!(Q, Qt)
193+
!isempty(L) && adjoint!(L, Lt)
194+
return L, Q
195+
end
196+
197+
function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractAlgorithm)
198+
m, n = size(A)
199+
minmn = min(m, n)
200+
At = adjoint!(similar(A'), A)
201+
Nt = similar(N')
202+
Nt = qr_null!(At, Nt, qr_alg)
203+
!isempty(N) && adjoint!(N, Nt)
204+
return N
205+
end

src/implementations/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ end
3030

3131
# Outputs
3232
# -------
33-
function initialize_output(::typeof(left_polar!), A::AbstractMatrix, ::PolarViaSVD)
33+
function initialize_output(::typeof(left_polar!), A::AbstractMatrix, ::AbstractAlgorithm)
3434
m, n = size(A)
3535
W = similar(A)
3636
P = similar(A, (n, n))
3737
return (W, P)
3838
end
39-
function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::PolarViaSVD)
39+
function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::AbstractAlgorithm)
4040
m, n = size(A)
4141
P = similar(A, (m, m))
4242
Wᴴ = similar(A)

0 commit comments

Comments
 (0)