Skip to content

Commit 9e66bf0

Browse files
committed
[WIP] More matrix functions, broadcasting, constructors
1 parent b49bc0a commit 9e66bf0

File tree

7 files changed

+303
-28
lines changed

7 files changed

+303
-28
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.17"
4+
version = "0.3.18"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/abstractdiagonalarray/abstractdiagonalarray.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,30 @@ abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end
44
const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2}
55
const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1}
66

7+
# Define for type stability, for some reason the generic versions
8+
# in SparseArraysBase.jl is not type stable.
9+
# TODO: Investigate type stability of `iszero` in SparseArraysBase.jl.
10+
function Base.iszero(a::AbstractDiagonalArray)
11+
return iszero(diagview(a))
12+
end
13+
14+
using FillArrays: AbstractFill, getindex_value
15+
using LinearAlgebra: norm
16+
# TODO: `_norm` works around:
17+
# https://github.com/JuliaArrays/FillArrays.jl/issues/417
18+
# Change back to `norm` when that is fixed.
19+
_norm(a, p::Int=2) = norm(a, p)
20+
function _norm(a::AbstractFill, p::Int=2)
21+
nrm1 = norm(getindex_value(a))
22+
return (length(a))^(1/oftype(nrm1, p)) * nrm1
23+
end
24+
function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int=2)
25+
# TODO: `_norm` works around:
26+
# https://github.com/JuliaArrays/FillArrays.jl/issues/417
27+
# Change back to `norm` when that is fixed.
28+
return _norm(diagview(a), p)
29+
end
30+
731
using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric
832
LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a)
933
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number})

src/abstractdiagonalarray/diagonalarraydiaginterface.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,25 +99,3 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}
9999
copyto!(diagview(dest), broadcasted_diagview(bc))
100100
return dest
101101
end
102-
103-
## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))
104-
105-
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
106-
## return a[StorageIndex(i)]
107-
## end
108-
109-
## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
110-
## a[StorageIndex(i)] = value
111-
## return a
112-
## end
113-
114-
## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))
115-
116-
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
117-
## return a[StorageIndices(i)]
118-
## end
119-
120-
## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
121-
## a[StorageIndices(i)] = value
122-
## return a
123-
## end

src/diaginterface/diaginterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ function setdiagindex!(a::AbstractArray, v, i::Integer)
9797
end
9898

9999
function getdiagindices(a::AbstractArray, I)
100+
# TODO: Should this be a view?
100101
return @view diagview(a)[I]
101102
end
102103

src/diagonalarray/diagonalarray.jl

Lines changed: 181 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,58 @@ SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
2020
Base.size(a::DiagonalArray) = size(unstored(a))
2121
Base.axes(a::DiagonalArray) = axes(unstored(a))
2222

23+
function DiagonalArray(diag::AbstractVector, unstored::Unstored)
24+
return _DiagonalArray(diag, parent(unstored))
25+
end
2326
function DiagonalArray(::UndefInitializer, unstored::Unstored)
24-
return _DiagonalArray(
25-
Vector{eltype(unstored)}(undef, minimum(size(unstored))), parent(unstored)
26-
)
27+
return DiagonalArray(Vector{eltype(unstored)}(undef, minimum(size(unstored))), unstored)
28+
end
29+
30+
# This helps to support diagonals where the elements are known
31+
# from the types, for example diagonals that are `Zeros` and `Ones`.
32+
function DiagonalArray{T,N,D,U}(
33+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
34+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
35+
return DiagonalArray(D((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
36+
end
37+
function DiagonalArray{T,N,D,U}(
38+
ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}}
39+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
40+
return DiagonalArray{T,N,D,U}((ax1, ax_rest...))
41+
end
42+
function DiagonalArray{T,N,D,U}(
43+
sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}}
44+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
45+
return DiagonalArray{T,N,D,U}(Base.OneTo.(sz))
46+
end
47+
function DiagonalArray{T,N,D,U}(
48+
sz1::Integer, sz_rest::Vararg{Integer}
49+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
50+
return DiagonalArray{T,N,D,U}((sz1, sz_rest...))
51+
end
52+
53+
# This helps to support diagonals where the elements are known
54+
# from the types, for example diagonals that are `Zeros` and `Ones`.
55+
# These versions use the default unstored type `Zeros{T,N}`.
56+
function DiagonalArray{T,N,D}(
57+
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
58+
) where {T,N,D<:AbstractVector{T}}
59+
return DiagonalArray{T,N,D,Zeros{T,N}}(ax)
60+
end
61+
function DiagonalArray{T,N,D}(
62+
ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}}
63+
) where {T,N,D<:AbstractVector{T}}
64+
return DiagonalArray{T,N,D,Zeros{T,N}}(ax1, ax_rest...)
65+
end
66+
function DiagonalArray{T,N,D}(
67+
sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}}
68+
) where {T,N,D<:AbstractVector{T}}
69+
return DiagonalArray{T,N,D,Zeros{T,N}}(sz)
70+
end
71+
function DiagonalArray{T,N,D}(
72+
sz1::Integer, sz_rest::Vararg{Integer}
73+
) where {T,N,D<:AbstractVector{T}}
74+
return DiagonalArray{T,N,D,Zeros{T,N}}(sz1, sz_rest...)
2775
end
2876

2977
# Constructors accepting axes.
@@ -32,7 +80,7 @@ function DiagonalArray{T,N}(
3280
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
3381
) where {T,N}
3482
N == length(ax) || throw(ArgumentError("Wrong number of axes"))
35-
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(ax))
83+
return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax)))
3684
end
3785
function DiagonalArray{T,N}(
3886
diag::AbstractVector,
@@ -97,7 +145,7 @@ function DiagonalArray{T}(
97145
end
98146

99147
function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N}
100-
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(dims))
148+
return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims)))
101149
end
102150

103151
function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N}
@@ -161,6 +209,28 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
161209
return DiagonalArray{T,N}(undef, dims)
162210
end
163211

212+
# 0-dim limit.
213+
function DiagonalArray{T,0,D}(
214+
::UndefInitializer, ax::Tuple{}
215+
) where {T,D<:AbstractVector{T}}
216+
return DiagonalArray{T,0,D}(D(undef, 0), ax)
217+
end
218+
function DiagonalArray{T,0,D}(::UndefInitializer) where {T,D<:AbstractVector{T}}
219+
return DiagonalArray{T,0,D}(undef, ())
220+
end
221+
function DiagonalArray{T,0}(::UndefInitializer, ax::Tuple{}) where {T}
222+
return DiagonalArray{T,0,Vector{T}}(undef, ax)
223+
end
224+
function DiagonalArray{T,0}(::UndefInitializer) where {T}
225+
return DiagonalArray{T,0}(undef, ())
226+
end
227+
function DiagonalArray{T}(::UndefInitializer, axes::Tuple{}) where {T}
228+
return DiagonalArray{T,0}(undef, ())
229+
end
230+
function DiagonalArray{T}(::UndefInitializer) where {T}
231+
return DiagonalArray{T}(undef, ())
232+
end
233+
164234
# Axes version
165235
function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N}
166236
return DiagonalArray{T,N}(undef, length.(axes))
@@ -197,3 +267,109 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
197267
# Unlike `permutedims(::Diagonal, perm)`, we copy here.
198268
return DiagonalArray(diagview(a), ax_perm)
199269
end
270+
271+
# Scalar indexing.
272+
using DerivableInterfaces: @interface, interface
273+
one_based_range(r) = false
274+
one_based_range(r::Base.OneTo) = true
275+
one_based_range(r::Base.Slice) = true
276+
function _diag_axes(a::DiagonalArray, I...)
277+
return map(ntuple(identity, ndims(a))) do d
278+
return Base.axes1(axes(a, d)[I[d]])
279+
end
280+
end
281+
# A view that preserves the diagonal structure.
282+
function _view_diag(a::DiagonalArray, I...)
283+
ax = _diag_axes(a, I...)
284+
return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax)
285+
end
286+
# A slice that preserves the diagonal structure.
287+
function _getindex_diag(a::DiagonalArray, I...)
288+
ax = _diag_axes(a, I...)
289+
return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax)
290+
end
291+
function Base.view(a::DiagonalArray, I...)
292+
I′ = to_indices(a, I)
293+
return if all(one_based_range, I′)
294+
_view_diag(a, I′...)
295+
else
296+
invoke(view, Tuple{AbstractArray,Vararg}, a, I′...)
297+
end
298+
end
299+
function Base.getindex(a::DiagonalArray, I::Int...)
300+
return @interface interface(a) a[I...]
301+
end
302+
function Base.getindex(a::DiagonalArray, I::DiagIndex)
303+
return getdiagindex(a, index(I))
304+
end
305+
function Base.getindex(a::DiagonalArray, I::DiagIndices)
306+
# TODO: Should this be a view?
307+
return @view diagview(a)[indices(I)]
308+
end
309+
function Base.getindex(a::DiagonalArray, I...)
310+
I′ = to_indices(a, I)
311+
return if all(i -> i isa Real, I′)
312+
# Catch scalar indexing case.
313+
@interface interface(a) a[I...]
314+
elseif all(one_based_range, I′)
315+
_getindex_diag(a, I′...)
316+
else
317+
copy(view(a, I′...))
318+
end
319+
end
320+
321+
# Define in order to preserve immutable diagonals such as FillArrays.
322+
function DiagonalArray{T,N}(a::DiagonalArray{T,N}) where {T,N}
323+
# TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`:
324+
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112
325+
return a
326+
end
327+
function DiagonalArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N}
328+
return DiagonalArray{T,N}(diagview(a))
329+
end
330+
function DiagonalArray{T}(a::DiagonalArray) where {T}
331+
return DiagonalArray{T,ndims(a)}(a)
332+
end
333+
function DiagonalArray(a::DiagonalArray)
334+
return DiagonalArray{eltype(a),ndims(a)}(a)
335+
end
336+
function Base.AbstractArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N}
337+
return DiagonalArray{T,N}(a)
338+
end
339+
340+
# TODO: These definitions work around this issue:
341+
# https://github.com/JuliaArrays/FillArrays.jl/issues/416
342+
# when the diagonal is a FillArrays.Ones or Zeros.
343+
using Base.Broadcast: Broadcast, broadcast, broadcasted
344+
using FillArrays: AbstractFill, Ones, Zeros
345+
_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a)
346+
_broadcasted(::typeof(identity), a::Ones) = a
347+
_broadcasted(::typeof(identity), a::Zeros) = a
348+
_broadcasted(::typeof(complex), a::Ones) = Ones{complex(eltype(a))}(axes(a))
349+
_broadcasted(::typeof(complex), a::Zeros) = Zeros{complex(eltype(a))}(axes(a))
350+
_broadcasted(elt::Type, a::Ones) = Ones{elt}(axes(a))
351+
_broadcasted(elt::Type, a::Zeros) = Zeros{elt}(axes(a))
352+
_broadcasted(::typeof(inv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a)
353+
using LinearAlgebra: pinv
354+
_broadcasted(::typeof(pinv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a)
355+
_broadcasted(::typeof(sqrt), a::Ones) = _broadcasted(typeof(sqrt(one(eltype(a)))), a)
356+
_broadcasted(::typeof(sqrt), a::Zeros) = _broadcasted(typeof(sqrt(zero(eltype(a)))), a)
357+
_broadcasted(::typeof(cbrt), a::Ones) = _broadcasted(typeof(cbrt(one(eltype(a)))), a)
358+
_broadcasted(::typeof(cbrt), a::Zeros) = _broadcasted(typeof(cbrt(zero(eltype(a)))), a)
359+
_broadcasted(::typeof(exp), a::Zeros) = Ones{typeof(exp(zero(eltype(a))))}(axes(a))
360+
_broadcasted(::typeof(cis), a::Zeros) = Ones{typeof(cis(zero(eltype(a))))}(axes(a))
361+
_broadcasted(::typeof(log), a::Ones) = Zeros{typeof(log(one(eltype(a))))}(axes(a))
362+
_broadcasted(::typeof(cos), a::Zeros) = Ones{typeof(cos(zero(eltype(a))))}(axes(a))
363+
_broadcasted(::typeof(sin), a::Zeros) = _broadcasted(typeof(sin(zero(eltype(a)))), a)
364+
_broadcasted(::typeof(tan), a::Zeros) = _broadcasted(typeof(tan(zero(eltype(a)))), a)
365+
_broadcasted(::typeof(sec), a::Zeros) = Ones{typeof(sec(zero(eltype(a))))}(axes(a))
366+
_broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axes(a))
367+
# Eager version of `_broadcasted`.
368+
_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a))
369+
370+
function Broadcast.broadcasted(
371+
::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T,N,Diag}
372+
) where {F,T,N,Diag<:AbstractFill{T}}
373+
# TODO: Check that `f` preserves zeros?
374+
return DiagonalArray(_broadcasted(f, diagview(a)), axes(a))
375+
end

src/diagonalarray/diagonalmatrix.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,91 @@ function LinearAlgebra.mul!(
5858
d_dest .= d1 .* d2 .* α .+ d_dest .* β
5959
return a_dest
6060
end
61+
62+
# Adapted from https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L866-L928.
63+
function LinearAlgebra.tr(a::DiagonalMatrix)
64+
checksquare(a)
65+
# TODO: Define as `sum(tr, diagview(a))` like LinearAlgebra.jl?
66+
return sum(diagview(a))
67+
end
68+
# TODO: Special case for FillArrays diagonals.
69+
function LinearAlgebra.det(a::DiagonalMatrix)
70+
checksquare(a)
71+
# TODO: Define as `prod(det, diagview(a))` like LinearAlgebra.jl?
72+
return prod(diagview(a))
73+
end
74+
# TODO: Special case for FillArrays diagonals.
75+
function LinearAlgebra.logabsdet(a::DiagonalMatrix)
76+
checksquare(a)
77+
return mapreduce(((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2), diagview(a)) do x
78+
return (log(abs(x)), sign(x))
79+
end
80+
end
81+
# TODO: Special case for FillArrays diagonals.
82+
function LinearAlgebra.logdet(a::DiagonalMatrix{<:Complex})
83+
checksquare(a)
84+
z = sum(log, diagview(a))
85+
return complex(real(z), rem2pi(imag(z), RoundNearest))
86+
end
87+
88+
# Matrix functions
89+
for f in (
90+
:exp,
91+
:cis,
92+
:log,
93+
:sqrt,
94+
:cos,
95+
:sin,
96+
:tan,
97+
:csc,
98+
:sec,
99+
:cot,
100+
:cosh,
101+
:sinh,
102+
:tanh,
103+
:csch,
104+
:sech,
105+
:coth,
106+
:acos,
107+
:asin,
108+
:atan,
109+
:acsc,
110+
:asec,
111+
:acot,
112+
:acosh,
113+
:asinh,
114+
:atanh,
115+
:acsch,
116+
:asech,
117+
:acoth,
118+
)
119+
@eval begin
120+
function Base.$f(a::DiagonalMatrix)
121+
checksquare(a)
122+
return DiagonalMatrix(_broadcast($f, diagview(a)), axes(a))
123+
end
124+
end
125+
end
126+
127+
# Cube root of a real-valued diagonal matrix
128+
function Base.cbrt(a::DiagonalMatrix{<:Real})
129+
checksquare(a)
130+
return DiagonalMatrix(_broadcast(cbrt, diagview(a)), axes(a))
131+
end
132+
133+
function LinearAlgebra.inv(a::DiagonalMatrix)
134+
checksquare(a)
135+
# `DiagonalArrays._broadcast` works around issues like https://github.com/JuliaArrays/FillArrays.jl/issues/416
136+
# when the diagonal is a FillArray or similar lazy array.
137+
d⁻¹ = _broadcast(inv, diagview(a))
138+
any(isinf, d⁻¹) && error("Singular Exception")
139+
return DiagonalMatrix(d⁻¹, axes(a))
140+
end
141+
142+
# TODO: Support `atol` and `rtol` keyword arguments:
143+
# https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.pinv
144+
using LinearAlgebra: pinv
145+
function LinearAlgebra.pinv(a::DiagonalMatrix)
146+
checksquare(a)
147+
return DiagonalMatrix(_broadcast(pinv, diagview(a)), axes(a))
148+
end

src/dual.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
11
# TODO: Define `TensorProducts.dual`.
22
dual(x) = x
33
issquare(a::AbstractMatrix) = (axes(a, 1) == dual(axes(a, 2)))
4+
# Like `LinearAlgebra.checksquare` but based on `DiagonalArrays.issquare`,
5+
# which checks the axes and allows customizing to check that the
6+
# codomain is the dual of the domain.
7+
# Returns the codomain if the check passes.
8+
function checksquare(a::AbstractMatrix)
9+
issquare(a) || throw(DimensionMismatch(lazy"matrix is not square: axes are $(axes(a))"))
10+
return axes(a, 1)
11+
end

0 commit comments

Comments
 (0)