Skip to content

Commit 7531d17

Browse files
authored
Support for Diagonal (#50)
* Add DiagonalAlgorithm * add Diagonal QR implementation and tests * Add Diagonal LQ implementation and tests * Add Diagonal eig implementation and tests * Add Diagonal eigh implementation and tests * Add Diagonal svd implementation and tests * Make JET happy * Add hermitian/symmetric checks * GPU-friendly QR/LQ * GPU-friendly SVD + correct gaugefix * Bump v0.3.1
1 parent c9727e4 commit 7531d17

File tree

18 files changed

+595
-73
lines changed

18 files changed

+595
-73
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <[email protected]> and contributors"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/MatrixAlgebraKit.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
66
using LinearAlgebra: sylvester
7-
using LinearAlgebra: isposdef, ishermitian
8-
using LinearAlgebra: Diagonal, diag, diagind
7+
using LinearAlgebra: isposdef, ishermitian, issymmetric
8+
using LinearAlgebra: Diagonal, diag, diagind, isdiag
99
using LinearAlgebra: UpperTriangular, LowerTriangular
1010
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt
1111

@@ -35,7 +35,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3535
LQViaTransposedQR,
3636
CUSOLVER_Simple,
3737
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
38-
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection
38+
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection,
39+
DiagonalAlgorithm
3940
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
4041

4142
VERSION >= v"1.11.0-DEV.469" &&

src/implementations/eig.jl

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
function copy_input(::typeof(eig_full), A::AbstractMatrix)
44
return copy!(similar(A, float(eltype(A))), A)
55
end
6-
function copy_input(::typeof(eig_vals), A::AbstractMatrix)
6+
function copy_input(::typeof(eig_vals), A)
77
return copy_input(eig_full, A)
88
end
99
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)
1010

11+
copy_input(::typeof(eig_full), A::Diagonal) = copy(A)
12+
1113
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1214
m, n = size(A)
1315
m == n || throw(DimensionMismatch("square input matrix expected"))
@@ -28,6 +30,28 @@ function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgori
2830
return nothing
2931
end
3032

33+
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
34+
m, n = size(A)
35+
@assert m == n && isdiag(A)
36+
D, V = DV
37+
@assert D isa Diagonal && V isa Diagonal
38+
@check_size(D, (m, m))
39+
@check_size(V, (m, m))
40+
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
41+
@check_scalar(D, A)
42+
@check_scalar(V, A)
43+
return nothing
44+
end
45+
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
46+
m, n = size(A)
47+
@assert m == n && isdiag(A)
48+
@assert D isa AbstractVector
49+
@check_size(D, (n,))
50+
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
51+
@check_scalar(D, A)
52+
return nothing
53+
end
54+
3155
# Outputs
3256
# -------
3357
function initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::AbstractAlgorithm)
@@ -47,9 +71,15 @@ function initialize_output(::typeof(eig_trunc!), A::AbstractMatrix, alg::Truncat
4771
return initialize_output(eig_full!, A, alg.alg)
4872
end
4973

74+
function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
75+
return A, similar(A)
76+
end
77+
function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm)
78+
return diagview(A)
79+
end
80+
5081
# Implementation
5182
# --------------
52-
# actual implementation
5383
function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
5484
check_input(eig_full!, A, DV, alg)
5585
D, V = DV
@@ -83,6 +113,24 @@ function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
83113
return truncate!(eig_trunc!, (D, V), alg.trunc)
84114
end
85115

116+
# Diagonal logic
117+
# --------------
118+
function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal,Diagonal}, alg::DiagonalAlgorithm)
119+
check_input(eig_full!, A, (D, V), alg)
120+
D === A || copy!(D, A)
121+
one!(V)
122+
return D, V
123+
end
124+
125+
function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm)
126+
check_input(eig_vals!, A, D, alg)
127+
Ad = diagview(A)
128+
D === Ad || copy!(D, Ad)
129+
return D
130+
end
131+
132+
# GPU logic
133+
# ---------
86134
_gpu_geev!(A::AbstractMatrix, D, V) = throw(MethodError(_gpu_geev!, (A, D, V)))
87135

88136
function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm)

src/implementations/eigh.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ function copy_input(::typeof(eigh_vals), A::AbstractMatrix)
88
end
99
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
1010

11+
copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
12+
1113
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
1214
m, n = size(A)
1315
m == n || throw(DimensionMismatch("square input matrix expected"))
@@ -21,6 +23,29 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgo
2123
end
2224
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
2325
m, n = size(A)
26+
m == n || throw(DimensionMismatch("square input matrix expected"))
27+
@assert D isa AbstractVector
28+
@check_size(D, (n,))
29+
@check_scalar(D, A, real)
30+
return nothing
31+
end
32+
33+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
34+
m, n = size(A)
35+
@assert m == n && isdiag(A)
36+
@assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A)
37+
D, V = DV
38+
@assert D isa Diagonal && V isa Diagonal
39+
@check_size(D, (m, m))
40+
@check_scalar(D, A, real)
41+
@check_size(V, (m, m))
42+
@check_scalar(V, A)
43+
return nothing
44+
end
45+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
46+
m, n = size(A)
47+
@assert m == n && isdiag(A)
48+
@assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A)
2449
@assert D isa AbstractVector
2550
@check_size(D, (n,))
2651
@check_scalar(D, A, real)
@@ -45,6 +70,13 @@ function initialize_output(::typeof(eigh_trunc!), A::AbstractMatrix,
4570
return initialize_output(eigh_full!, A, alg.alg)
4671
end
4772

73+
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
74+
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
75+
end
76+
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
77+
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
78+
end
79+
4880
# Implementation
4981
# --------------
5082
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
@@ -85,6 +117,25 @@ function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
85117
return truncate!(eigh_trunc!, (D, V), alg.trunc)
86118
end
87119

120+
# Diagonal logic
121+
# --------------
122+
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
123+
check_input(eigh_full!, A, DV, alg)
124+
D, V = DV
125+
D === A || (diagview(D) .= real.(diagview(A)))
126+
one!(V)
127+
return D, V
128+
end
129+
130+
function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
131+
check_input(eigh_vals!, A, D, alg)
132+
Ad = diagview(A)
133+
D === Ad || (D .= real.(Ad))
134+
return D
135+
end
136+
137+
# GPU logic
138+
# ---------
88139
_gpu_heevj!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevj!, (A, Dd, V)))
89140
_gpu_heevd!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heevd!, (A, Dd, V)))
90141
_gpu_heev!(A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...) = throw(MethodError(_gpu_heev!, (A, Dd, V)))

src/implementations/lq.jl

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
# Inputs
22
# ------
3-
function copy_input(::typeof(lq_full), A::AbstractMatrix)
4-
return copy!(similar(A, float(eltype(A))), A)
5-
end
6-
function copy_input(::typeof(lq_compact), A::AbstractMatrix)
7-
return copy!(similar(A, float(eltype(A))), A)
8-
end
9-
function copy_input(::typeof(lq_null), A::AbstractMatrix)
10-
return copy!(similar(A, float(eltype(A))), A)
3+
for f in (:lq_full, :lq_compact, :lq_null)
4+
@eval function copy_input(::typeof($f), A::AbstractMatrix)
5+
return copy!(similar(A, float(eltype(A))), A)
6+
end
7+
@eval copy_input(::typeof($f), A::Diagonal) = copy(A)
118
end
129

1310
function check_input(::typeof(lq_full!), A::AbstractMatrix, LQ, ::AbstractAlgorithm)
@@ -40,6 +37,28 @@ function check_input(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgo
4037
return nothing
4138
end
4239

40+
function check_input(::typeof(lq_full!), A::AbstractMatrix, (L, Q), ::DiagonalAlgorithm)
41+
m, n = size(A)
42+
@assert m == n && isdiag(A)
43+
@assert Q isa Diagonal && L isa Diagonal
44+
isempty(L) || @check_size(L, (m, n))
45+
@check_scalar(L, A)
46+
@check_size(Q, (n, n))
47+
@check_scalar(Q, A)
48+
return nothing
49+
end
50+
function check_input(::typeof(lq_compact!), A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
51+
return check_input(lq_full!, A, LQ, alg)
52+
end
53+
function check_input(::typeof(lq_null!), A::AbstractMatrix, N, ::DiagonalAlgorithm)
54+
m, n = size(A)
55+
@assert m == n && isdiag(A)
56+
@assert N isa AbstractMatrix
57+
@check_size(N, (0, m))
58+
@check_scalar(N, A)
59+
return nothing
60+
end
61+
4362
# Outputs
4463
# -------
4564
function initialize_output(::typeof(lq_full!), A::AbstractMatrix, ::AbstractAlgorithm)
@@ -62,44 +81,69 @@ function initialize_output(::typeof(lq_null!), A::AbstractMatrix, ::AbstractAlgo
6281
return Nᴴ
6382
end
6483

84+
for f! in (:lq_full!, :lq_compact!)
85+
@eval function initialize_output(::typeof($f!), A::AbstractMatrix, ::DiagonalAlgorithm)
86+
return similar(A), A
87+
end
88+
end
89+
6590
# Implementation
6691
# --------------
67-
# actual implementation
6892
function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
6993
check_input(lq_full!, A, LQ, alg)
7094
L, Q = LQ
7195
_lapack_lq!(A, L, Q; alg.kwargs...)
7296
return L, Q
7397
end
74-
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
75-
check_input(lq_full!, A, LQ, alg)
76-
L, Q = LQ
77-
lq_via_qr!(A, L, Q, alg.qr_alg)
78-
return L, Q
79-
end
8098
function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ)
8199
check_input(lq_compact!, A, LQ, alg)
82100
L, Q = LQ
83101
_lapack_lq!(A, L, Q; alg.kwargs...)
84102
return L, Q
85103
end
104+
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
105+
check_input(lq_null!, A, Nᴴ, alg)
106+
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
107+
return Nᴴ
108+
end
109+
110+
function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
111+
check_input(lq_full!, A, LQ, alg)
112+
L, Q = LQ
113+
lq_via_qr!(A, L, Q, alg.qr_alg)
114+
return L, Q
115+
end
86116
function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR)
87117
check_input(lq_compact!, A, LQ, alg)
88118
L, Q = LQ
89119
lq_via_qr!(A, L, Q, alg.qr_alg)
90120
return L, Q
91121
end
92-
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ)
93-
check_input(lq_null!, A, Nᴴ, alg)
94-
_lapack_lq_null!(A, Nᴴ; alg.kwargs...)
95-
return Nᴴ
96-
end
97122
function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR)
98123
check_input(lq_null!, A, Nᴴ, alg)
99124
lq_null_via_qr!(A, Nᴴ, alg.qr_alg)
100125
return Nᴴ
101126
end
102127

128+
function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
129+
check_input(lq_full!, A, LQ, alg)
130+
L, Q = LQ
131+
_diagonal_lq!(A, L, Q; alg.kwargs...)
132+
return L, Q
133+
end
134+
function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm)
135+
check_input(lq_compact!, A, LQ, alg)
136+
L, Q = LQ
137+
_diagonal_lq!(A, L, Q; alg.kwargs...)
138+
return L, Q
139+
end
140+
function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm)
141+
check_input(lq_null!, A, N, alg)
142+
return _diagonal_lq_null!(A, N; alg.kwargs...)
143+
end
144+
145+
# LAPACK logic
146+
# ------------
103147
function _lapack_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
104148
positive=false,
105149
pivoted=false,
@@ -177,6 +221,7 @@ function _lapack_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix;
177221
end
178222

179223
# LQ via transposition and QR
224+
# ---------------------------
180225
function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
181226
qr_alg::AbstractAlgorithm)
182227
m, n = size(A)
@@ -203,3 +248,23 @@ function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractA
203248
!isempty(N) && adjoint!(N, Nt)
204249
return N
205250
end
251+
252+
# Diagonal logic
253+
# --------------
254+
function _diagonal_lq!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
255+
positive::Bool=false)
256+
# note: Ad and Qd might share memory here so order of operations is important
257+
Ad = diagview(A)
258+
Ld = diagview(L)
259+
Qd = diagview(Q)
260+
if positive
261+
@. Ld = abs(Ad)
262+
@. Qd = sign_safe(Ad)
263+
else
264+
Ld .= Ad
265+
one!(Q)
266+
end
267+
return L, Q
268+
end
269+
270+
_diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool=false) = N

0 commit comments

Comments
 (0)