Skip to content

Commit 6c34b1a

Browse files
authored
Add MatrixAlgebraKit factorizations for (scaled) deltas (#40)
1 parent f6a66da commit 6c34b1a

File tree

10 files changed

+580
-5
lines changed

10 files changed

+580
-5
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.15"
4+
version = "0.3.16"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -11,11 +11,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1212
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1313

14+
[weakdeps]
15+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
16+
17+
[extensions]
18+
DiagonalArraysMatrixAlgebraKitExt = "MatrixAlgebraKit"
19+
1420
[compat]
1521
ArrayLayouts = "1.10.4"
1622
DerivableInterfaces = "0.5.5"
1723
FillArrays = "1.13.0"
1824
LinearAlgebra = "1.10.0"
1925
MapBroadcast = "0.1.10"
26+
MatrixAlgebraKit = "0.2"
2027
SparseArraysBase = "0.7.2"
2128
julia = "1.10"
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
module DiagonalArraysMatrixAlgebraKitExt
2+
3+
using DiagonalArrays:
4+
AbstractDiagonalMatrix,
5+
DeltaMatrix,
6+
DiagonalMatrix,
7+
ScaledDeltaMatrix,
8+
δ,
9+
diagview,
10+
dual,
11+
issquare
12+
using LinearAlgebra: LinearAlgebra, isdiag, ishermitian
13+
using MatrixAlgebraKit:
14+
MatrixAlgebraKit,
15+
AbstractAlgorithm,
16+
check_input,
17+
default_qr_algorithm,
18+
eig_full,
19+
eig_full!,
20+
eig_vals,
21+
eig_vals!,
22+
eigh_full,
23+
eigh_full!,
24+
eigh_vals,
25+
eigh_vals!,
26+
left_null,
27+
left_null!,
28+
left_orth,
29+
left_orth!,
30+
left_polar,
31+
left_polar!,
32+
lq_compact,
33+
lq_compact!,
34+
lq_full,
35+
lq_full!,
36+
qr_compact,
37+
qr_compact!,
38+
qr_full,
39+
qr_full!,
40+
right_null,
41+
right_null!,
42+
right_orth,
43+
right_orth!,
44+
right_polar,
45+
right_polar!,
46+
svd_compact,
47+
svd_compact!,
48+
svd_full,
49+
svd_full!,
50+
svd_vals,
51+
svd_vals!
52+
53+
abstract type AbstractDiagonalAlgorithm <: AbstractAlgorithm end
54+
55+
struct DeltaAlgorithm{KWargs<:NamedTuple} <: AbstractDiagonalAlgorithm
56+
kwargs::KWargs
57+
end
58+
DeltaAlgorithm(; kwargs...) = DeltaAlgorithm((; kwargs...))
59+
60+
struct ScaledDeltaAlgorithm{KWargs<:NamedTuple} <: AbstractDiagonalAlgorithm
61+
kwargs::KWargs
62+
end
63+
ScaledDeltaAlgorithm(; kwargs...) = ScaledDeltaAlgorithm((; kwargs...))
64+
65+
for f in [
66+
:eig_full,
67+
:eig_vals,
68+
:eigh_full,
69+
:eigh_vals,
70+
:qr_compact,
71+
:qr_full,
72+
:left_null,
73+
:left_orth,
74+
:left_polar,
75+
:lq_compact,
76+
:lq_full,
77+
:right_null,
78+
:right_orth,
79+
:right_polar,
80+
:svd_compact,
81+
:svd_full,
82+
:svd_vals,
83+
]
84+
@eval begin
85+
MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractDiagonalMatrix) = copy(a)
86+
end
87+
end
88+
89+
for f in [
90+
:default_eig_algorithm,
91+
:default_eigh_algorithm,
92+
:default_lq_algorithm,
93+
:default_qr_algorithm,
94+
:default_polar_algorithm,
95+
:default_svd_algorithm,
96+
]
97+
@eval begin
98+
function MatrixAlgebraKit.$f(::Type{<:DeltaMatrix}; kwargs...)
99+
return DeltaAlgorithm(; kwargs...)
100+
end
101+
function MatrixAlgebraKit.$f(::Type{<:ScaledDeltaMatrix}; kwargs...)
102+
return ScaledDeltaAlgorithm(; kwargs...)
103+
end
104+
end
105+
end
106+
107+
for f in [
108+
:eig_full!,
109+
:eig_vals!,
110+
:eigh_full!,
111+
:eigh_vals!,
112+
:left_null!,
113+
:left_orth!,
114+
:left_polar!,
115+
:lq_compact!,
116+
:lq_full!,
117+
:qr_compact!,
118+
:qr_full!,
119+
:right_null!,
120+
:right_orth!,
121+
:right_polar!,
122+
:svd_compact!,
123+
:svd_full!,
124+
:svd_vals!,
125+
]
126+
for Alg in [:ScaledDeltaAlgorithm, :DeltaAlgorithm]
127+
@eval begin
128+
function MatrixAlgebraKit.initialize_output(::typeof($f), a, alg::$Alg)
129+
return nothing
130+
end
131+
end
132+
end
133+
end
134+
135+
for f in [
136+
:left_null!,
137+
:left_orth!,
138+
:left_polar!,
139+
:lq_compact!,
140+
:lq_full!,
141+
:qr_compact!,
142+
:qr_full!,
143+
:right_null!,
144+
:right_orth!,
145+
:right_polar!,
146+
:svd_compact!,
147+
:svd_full!,
148+
:svd_vals!,
149+
]
150+
@eval begin
151+
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::DeltaAlgorithm)
152+
@assert size(a, 1) == size(a, 2)
153+
@assert isdiag(a)
154+
@assert all(isone, diagview(a))
155+
return nothing
156+
end
157+
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::ScaledDeltaAlgorithm)
158+
@assert size(a, 1) == size(a, 2)
159+
@assert isdiag(a)
160+
@assert allequal(diagview(a))
161+
return nothing
162+
end
163+
end
164+
end
165+
for f in [:eig_full!, :eig_vals!, :eigh_full!, :eigh_vals!]
166+
@eval begin
167+
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::DeltaAlgorithm)
168+
@assert issquare(a)
169+
@assert isdiag(a)
170+
@assert all(isone, diagview(a))
171+
return nothing
172+
end
173+
function MatrixAlgebraKit.check_input(::typeof($f), a, F, alg::ScaledDeltaAlgorithm)
174+
@assert issquare(a)
175+
@assert isdiag(a)
176+
@assert allequal(diagview(a))
177+
return nothing
178+
end
179+
end
180+
end
181+
182+
# eig
183+
for Alg in [:DeltaAlgorithm, :ScaledDeltaAlgorithm]
184+
@eval begin
185+
function MatrixAlgebraKit.eig_full!(a, F, alg::$Alg)
186+
check_input(eig_full!, a, F, alg)
187+
d = complex(a)
188+
v = δ(complex(eltype(a)), axes(a))
189+
return (d, v)
190+
end
191+
function MatrixAlgebraKit.eigh_full!(a, F, alg::$Alg)
192+
check_input(eigh_full!, a, F, alg)
193+
ishermitian(a) || throw(ArgumentError("Matrix must be Hermitian"))
194+
d = real(a)
195+
v = δ(eltype(a), axes(a))
196+
return (d, v)
197+
end
198+
function MatrixAlgebraKit.eig_vals!(a, F, alg::$Alg)
199+
check_input(eig_vals!, a, F, alg)
200+
return complex(diagview(a))
201+
end
202+
function MatrixAlgebraKit.eigh_vals!(a, F, alg::$Alg)
203+
check_input(eigh_vals!, a, F, alg)
204+
return real(diagview(a))
205+
end
206+
end
207+
end
208+
209+
# svd
210+
for f in [:svd_compact!, :svd_full!]
211+
@eval begin
212+
function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm)
213+
check_input($f, a, F, alg)
214+
u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1))))
215+
s = real(a)
216+
v = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2)))
217+
return (u, s, v)
218+
end
219+
function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm)
220+
check_input($f, a, F, alg)
221+
diagvalue = only(unique(diagview(a)))
222+
u = δ(eltype(a), (axes(a, 1), dual(axes(a, 1))))
223+
s = abs(diagvalue) * δ(Bool, axes(a))
224+
# Sign is applied arbitarily to `v`, alternatively
225+
# we could apply it to `u`.
226+
v = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2)))
227+
return (u, s, v)
228+
end
229+
end
230+
end
231+
function MatrixAlgebraKit.svd_vals!(a, F, alg::DeltaAlgorithm)
232+
check_input(svd_vals!, a, F, alg)
233+
# Using `real` instead of `abs.` helps to preserve `Ones`.
234+
return real(diagview(a))
235+
end
236+
function MatrixAlgebraKit.svd_vals!(a, F, alg::ScaledDeltaAlgorithm)
237+
check_input(svd_vals!, a, F, alg)
238+
return abs.(diagview(a))
239+
end
240+
241+
# orth
242+
# left_orth is implicitly defined by defining backends like
243+
# qr_compact and left_polar.
244+
for f in [:left_polar!, :qr_compact!, :qr_full!]
245+
@eval begin
246+
function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm)
247+
check_input($f, a, F, alg)
248+
q = δ(eltype(a), (axes(a, 1), dual(axes(a, 1))))
249+
r = copy(a)
250+
return (q, r)
251+
end
252+
function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm)
253+
check_input($f, a, F, alg)
254+
diagvalue = only(unique(diagview(a)))
255+
q = sign(diagvalue) * δ(Bool, (axes(a, 1), dual(axes(a, 1))))
256+
# We're a bit pessimistic about the element type for type stability,
257+
# since in the future we might provide the option to do non-positive QR.
258+
r = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a))
259+
return (q, r)
260+
end
261+
end
262+
end
263+
# right_orth is implicitly defined by defining backends like
264+
# lq_compact and right_polar.
265+
for f in [:right_polar!, :lq_compact!, :lq_full!]
266+
@eval begin
267+
function MatrixAlgebraKit.$f(a, F, alg::DeltaAlgorithm)
268+
check_input($f, a, F, alg)
269+
l = copy(a)
270+
q = δ(eltype(a), (dual(axes(a, 2)), axes(a, 2)))
271+
return (l, q)
272+
end
273+
function MatrixAlgebraKit.$f(a, F, alg::ScaledDeltaAlgorithm)
274+
check_input($f, a, F, alg)
275+
diagvalue = only(unique(diagview(a)))
276+
# We're a bit pessimistic about the element type for type stability,
277+
# since in the future we might provide the option to do non-positive LQ.
278+
l = eltype(a)(abs(diagvalue)) * δ(Bool, axes(a))
279+
q = sign(diagvalue) * δ(Bool, (dual(axes(a, 2)), axes(a, 2)))
280+
return (l, q)
281+
end
282+
end
283+
end
284+
285+
# null
286+
for T in [:DeltaMatrix, :ScaledDeltaMatrix]
287+
@eval begin
288+
# TODO: Right now we can't overload `left_null!` on an algorithm,
289+
# make a PR to MatrixAlgebraKit.jl to allow that.
290+
function MatrixAlgebraKit.left_null!(a::$T, F)
291+
check_input(left_null!, a, F, default_qr_algorithm(a))
292+
return error("Not implemented.")
293+
end
294+
# TODO: Right now we can't overload `right_null!` on an algorithm,
295+
# make a PR to MatrixAlgebraKit.jl to allow that.
296+
function MatrixAlgebraKit.right_null!(a::$T, F)
297+
check_input(right_null!, a, F, default_qr_algorithm(a))
298+
return error("Not implemented.")
299+
end
300+
end
301+
end
302+
303+
end

src/DiagonalArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module DiagonalArrays
22

3+
include("dual.jl")
34
include("diaginterface/diaginterface.jl")
45
include("diaginterface/diagindex.jl")
56
include("diaginterface/diagindices.jl")
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
11
using SparseArraysBase: AbstractSparseArray
22

33
abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end
4+
const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2}
5+
const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1}
6+
7+
using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric
8+
LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a)
9+
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number})
10+
return issquare(a) && isreal(diagview(a))
11+
end
12+
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix)
13+
return issquare(a) && all(ishermitian, diagview(a))
14+
end
15+
LinearAlgebra.issymmetric(a::AbstractDiagonalMatrix{<:Number}) = issquare(a)
16+
function LinearAlgebra.issymmetric(a::AbstractDiagonalMatrix)
17+
return issquare(a) && all(issymmetric, diagview(a))
18+
end
19+
function LinearAlgebra.isposdef(a::AbstractDiagonalMatrix)
20+
return issquare(a) && all(isposdef, diagview(a))
21+
end

src/diagonalarray/diagonalarray.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,13 @@ function Base.similar(a::DiagonalArray, unstored::Unstored)
170170
return DiagonalArray(undef, unstored)
171171
end
172172

173-
# This definition is helpful for immutable diagonals
173+
# These definitions are helpful for immutable diagonals
174174
# such as FillArrays.
175-
Base.copy(a::DiagonalArray) = DiagonalArray(copy(diagview(a)), axes(a))
175+
for f in [:complex, :copy, :imag, :real]
176+
@eval begin
177+
Base.$f(a::DiagonalArray) = DiagonalArray($f(diagview(a)), axes(a))
178+
end
179+
end
176180

177181
# DiagonalArrays interface.
178182
diagview(a::DiagonalArray) = a.diag

src/diagonalarray/diagonalmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using LinearAlgebra: LinearAlgebra
88

99
function mul_diagviews(a1, a2)
1010
# TODO: Compare that duals are equal, or define a function to overload.
11-
axes(a1, 2) == axes(a2, 1) || throw(
11+
dual(axes(a1, 2)) == axes(a2, 1) || throw(
1212
DimensionMismatch(
1313
lazy"Incompatible dimensions for multiplication: $(axes(a1)) and $(axes(a2))"
1414
),

src/dual.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# TODO: Define `TensorProducts.dual`.
2+
dual(x) = x
3+
issquare(a::AbstractMatrix) = (axes(a, 1) == dual(axes(a, 2)))

0 commit comments

Comments
 (0)