Skip to content

Commit 2c04cbb

Browse files
authored
More matrix functions, broadcasting, constructors (#43)
1 parent b49bc0a commit 2c04cbb

File tree

9 files changed

+567
-49
lines changed

9 files changed

+567
-49
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
matrix:
1919
pkg:
2020
- 'BlockSparseArrays'
21+
- 'KroneckerArrays'
2122
uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main"
2223
with:
2324
localregistry: "https://github.com/ITensor/ITensorRegistry.git"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.17"
4+
version = "0.3.18"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/abstractdiagonalarray/abstractdiagonalarray.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,30 @@ abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end
44
const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2}
55
const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1}
66

7+
# Define for type stability, for some reason the generic versions
8+
# in SparseArraysBase.jl is not type stable.
9+
# TODO: Investigate type stability of `iszero` in SparseArraysBase.jl.
10+
function Base.iszero(a::AbstractDiagonalArray)
11+
return iszero(diagview(a))
12+
end
13+
14+
using FillArrays: AbstractFill, getindex_value
15+
using LinearAlgebra: norm
16+
# TODO: `_norm` works around:
17+
# https://github.com/JuliaArrays/FillArrays.jl/issues/417
18+
# Change back to `norm` when that is fixed.
19+
_norm(a, p::Int=2) = norm(a, p)
20+
function _norm(a::AbstractFill, p::Int=2)
21+
nrm1 = norm(getindex_value(a))
22+
return (length(a))^(1/oftype(nrm1, p)) * nrm1
23+
end
24+
function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int=2)
25+
# TODO: `_norm` works around:
26+
# https://github.com/JuliaArrays/FillArrays.jl/issues/417
27+
# Change back to `norm` when that is fixed.
28+
return _norm(diagview(a), p)
29+
end
30+
731
using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric
832
LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a)
933
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number})

src/abstractdiagonalarray/diagonalarraydiaginterface.jl

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ function SparseArraysBase.getstoredindex(
4343
# allequal(I) || error("Not a diagonal index.")
4444
return getdiagindex(a, first(I))
4545
end
46+
function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any,0})
47+
return getdiagindex(a, 1)
48+
end
4649
function SparseArraysBase.setstoredindex!(
4750
a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N}
4851
) where {N}
@@ -52,6 +55,10 @@ function SparseArraysBase.setstoredindex!(
5255
setdiagindex!(a, value, first(I))
5356
return a
5457
end
58+
function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any,0}, value)
59+
setdiagindex!(a, value, 1)
60+
return a
61+
end
5562
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray)
5663
return diagindices(a)
5764
end
@@ -99,25 +106,3 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}
99106
copyto!(diagview(dest), broadcasted_diagview(bc))
100107
return dest
101108
end
102-
103-
## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))
104-
105-
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
106-
## return a[StorageIndex(i)]
107-
## end
108-
109-
## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
110-
## a[StorageIndex(i)] = value
111-
## return a
112-
## end
113-
114-
## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))
115-
116-
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
117-
## return a[StorageIndices(i)]
118-
## end
119-
120-
## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
121-
## a[StorageIndices(i)] = value
122-
## return a
123-
## end

src/diaginterface/diaginterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ function setdiagindex!(a::AbstractArray, v, i::Integer)
9797
end
9898

9999
function getdiagindices(a::AbstractArray, I)
100+
# TODO: Should this be a view?
100101
return @view diagview(a)[I]
101102
end
102103

src/diagonalarray/diagonalarray.jl

Lines changed: 208 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,112 @@
11
using FillArrays: Zeros
22
using SparseArraysBase: Unstored, unstored
33

4-
function _DiagonalArray end
4+
diaglength_from_shape(sz::Tuple{Integer,Vararg{Integer}}) = minimum(sz)
5+
function diaglength_from_shape(
6+
sz::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
7+
)
8+
return minimum(length, sz)
9+
end
10+
diaglength_from_shape(sz::Tuple{}) = 1
511

6-
struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} <:
12+
struct DiagonalArray{T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} <:
713
AbstractDiagonalArray{T,N}
8-
diag::Diag
9-
unstored::Unstored
10-
global @inline function _DiagonalArray(
11-
diag::Diag, unstored::Unstored
12-
) where {T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}}
13-
length(diag) == minimum(size(unstored)) ||
14+
diag::D
15+
unstored::U
16+
function DiagonalArray{T,N,D,U}(
17+
diag::AbstractVector, unstored::Unstored
18+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
19+
length(diag) == diaglength_from_shape(size(unstored)) ||
1420
throw(ArgumentError("Length of diagonals doesn't match dimensions"))
15-
return new{T,N,Diag,Unstored}(diag, unstored)
21+
return new{T,N,D,U}(diag, parent(unstored))
1622
end
1723
end
1824

1925
SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
2026
Base.size(a::DiagonalArray) = size(unstored(a))
2127
Base.axes(a::DiagonalArray) = axes(unstored(a))
2228

29+
function DiagonalArray{T,N,D}(
30+
diag::D, unstored::Unstored{T,N,U}
31+
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
32+
return DiagonalArray{T,N,D,U}(diag, unstored)
33+
end
34+
function DiagonalArray{T,N}(
35+
diag::D, unstored::Unstored{T,N}
36+
) where {T,N,D<:AbstractVector{T}}
37+
return DiagonalArray{T,N,D}(diag, unstored)
38+
end
39+
function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N}
40+
return DiagonalArray{T,N}(diag, unstored)
41+
end
42+
function DiagonalArray(diag::AbstractVector{T}, unstored::Unstored{T}) where {T}
43+
return DiagonalArray{T}(diag, unstored)
44+
end
45+
2346
function DiagonalArray(::UndefInitializer, unstored::Unstored)
24-
return _DiagonalArray(
25-
Vector{eltype(unstored)}(undef, minimum(size(unstored))), parent(unstored)
47+
return DiagonalArray(
48+
Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored
49+
)
50+
end
51+
52+
# Indicate we will construct an array just from the shape,
53+
# for example for a Base.OneTo or FillArrays.Ones or Zeros.
54+
# All the elements should be uniquely defined by the input axes.
55+
struct ShapeInitializer end
56+
57+
# This is used to create custom constructors for arrays,
58+
# in this case a generic constructor of a vector from a length.
59+
function construct(vect::Type{<:AbstractVector}, ::ShapeInitializer, len::Integer)
60+
if applicable(vect, len)
61+
return vect(len)
62+
elseif applicable(vect, (Base.OneTo(len),))
63+
return vect((Base.OneTo(len),))
64+
else
65+
error(lazy"Can't construct $(vect) from length.")
66+
end
67+
end
68+
69+
# This helps to support diagonals where the elements are known
70+
# from the types, for example diagonals that are `Zeros` and `Ones`.
71+
function DiagonalArray{T,N,D}(
72+
init::ShapeInitializer, unstored::Unstored
73+
) where {T,N,D<:AbstractVector{T}}
74+
return DiagonalArray{T,N,D}(
75+
construct(D, init, diaglength_from_shape(axes(unstored))), unstored
2676
)
2777
end
2878

29-
# Constructors accepting axes.
79+
# This helps to support diagonals where the elements are known
80+
# from the types, for example diagonals that are `Zeros` and `Ones`.
81+
# These versions use the default unstored type `Zeros{T,N}`.
82+
function DiagonalArray{T,N,D}(
83+
init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
84+
) where {T,N,D<:AbstractVector{T}}
85+
return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax)))
86+
end
87+
function DiagonalArray{T,N,D}(
88+
init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}...
89+
) where {T,N,D<:AbstractVector{T}}
90+
return DiagonalArray{T,N,D}(init, ax)
91+
end
92+
function DiagonalArray{T,N,D}(
93+
init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}}
94+
) where {T,N,D<:AbstractVector{T}}
95+
return DiagonalArray{T,N,D}(init, Base.OneTo.(sz))
96+
end
97+
function DiagonalArray{T,N,D}(
98+
init::ShapeInitializer, sz1::Integer, sz_rest::Integer...
99+
) where {T,N,D<:AbstractVector{T}}
100+
return DiagonalArray{T,N,D}(init, (sz1, sz_rest...))
101+
end
102+
103+
# Constructor from diagonal entries accepting axes.
30104
function DiagonalArray{T,N}(
31105
diag::AbstractVector,
32106
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
33107
) where {T,N}
34108
N == length(ax) || throw(ArgumentError("Wrong number of axes"))
35-
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(ax))
109+
return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax)))
36110
end
37111
function DiagonalArray{T,N}(
38112
diag::AbstractVector,
@@ -97,7 +171,7 @@ function DiagonalArray{T}(
97171
end
98172

99173
function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N}
100-
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(dims))
174+
return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims)))
101175
end
102176

103177
function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N}
@@ -146,7 +220,7 @@ end
146220

147221
# undef
148222
function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
149-
return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims)
223+
return DiagonalArray{T,N}(Vector{T}(undef, diaglength_from_shape(dims)), dims)
150224
end
151225

152226
function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
@@ -162,8 +236,10 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
162236
end
163237

164238
# Axes version
165-
function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N}
166-
return DiagonalArray{T,N}(undef, length.(axes))
239+
function DiagonalArray{T}(
240+
::UndefInitializer, axes::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}}
241+
) where {T}
242+
return DiagonalArray{T,length(axes)}(undef, length.(axes))
167243
end
168244

169245
function Base.similar(a::DiagonalArray, unstored::Unstored)
@@ -197,3 +273,118 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
197273
# Unlike `permutedims(::Diagonal, perm)`, we copy here.
198274
return DiagonalArray(diagview(a), ax_perm)
199275
end
276+
277+
# Scalar indexing.
278+
using DerivableInterfaces: @interface, interface
279+
one_based_range(r) = false
280+
one_based_range(r::Base.OneTo) = true
281+
one_based_range(r::Base.Slice) = true
282+
function _diag_axes(a::DiagonalArray, I...)
283+
return map(ntuple(identity, ndims(a))) do d
284+
return Base.axes1(axes(a, d)[I[d]])
285+
end
286+
end
287+
# A view that preserves the diagonal structure.
288+
function _view_diag(a::DiagonalArray, I...)
289+
ax = _diag_axes(a, I...)
290+
return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax)
291+
end
292+
function _view_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...)
293+
ax = _diag_axes(a, I1, Irest...)
294+
return DiagonalArray(view(diagview(a), :), ax)
295+
end
296+
# A slice that preserves the diagonal structure.
297+
function _getindex_diag(a::DiagonalArray, I...)
298+
ax = _diag_axes(a, I...)
299+
return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax)
300+
end
301+
function _getindex_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...)
302+
ax = _diag_axes(a, I1, Irest...)
303+
return DiagonalArray(diagview(a)[:], ax)
304+
end
305+
function Base.view(a::DiagonalArray, I...)
306+
I′ = to_indices(a, I)
307+
return if all(one_based_range, I′)
308+
_view_diag(a, I′...)
309+
else
310+
invoke(view, Tuple{AbstractArray,Vararg}, a, I′...)
311+
end
312+
end
313+
function Base.getindex(a::DiagonalArray, I::Int...)
314+
return @interface interface(a) a[I...]
315+
end
316+
function Base.getindex(a::DiagonalArray, I::DiagIndex)
317+
return getdiagindex(a, index(I))
318+
end
319+
function Base.getindex(a::DiagonalArray, I::DiagIndices)
320+
# TODO: Should this be a view?
321+
return @view diagview(a)[indices(I)]
322+
end
323+
function Base.getindex(a::DiagonalArray, I...)
324+
I′ = to_indices(a, I)
325+
return if all(i -> i isa Real, I′)
326+
# Catch scalar indexing case.
327+
@interface interface(a) a[I...]
328+
elseif all(one_based_range, I′)
329+
_getindex_diag(a, I′...)
330+
else
331+
copy(view(a, I′...))
332+
end
333+
end
334+
335+
# Define in order to preserve immutable diagonals such as FillArrays.
336+
function DiagonalArray{T,N}(a::DiagonalArray{T,N}) where {T,N}
337+
# TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`:
338+
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112
339+
return a
340+
end
341+
function DiagonalArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N}
342+
return DiagonalArray{T,N}(diagview(a))
343+
end
344+
function DiagonalArray{T}(a::DiagonalArray) where {T}
345+
return DiagonalArray{T,ndims(a)}(a)
346+
end
347+
function DiagonalArray(a::DiagonalArray)
348+
return DiagonalArray{eltype(a),ndims(a)}(a)
349+
end
350+
function Base.AbstractArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N}
351+
return DiagonalArray{T,N}(a)
352+
end
353+
354+
# TODO: These definitions work around this issue:
355+
# https://github.com/JuliaArrays/FillArrays.jl/issues/416
356+
# when the diagonal is a FillArrays.Ones or Zeros.
357+
using Base.Broadcast: Broadcast, broadcast, broadcasted
358+
using FillArrays: AbstractFill, Ones, Zeros
359+
_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a)
360+
_broadcasted(::typeof(identity), a::Ones) = a
361+
_broadcasted(::typeof(identity), a::Zeros) = a
362+
_broadcasted(::typeof(complex), a::Ones) = Ones{complex(eltype(a))}(axes(a))
363+
_broadcasted(::typeof(complex), a::Zeros) = Zeros{complex(eltype(a))}(axes(a))
364+
_broadcasted(elt::Type, a::Ones) = Ones{elt}(axes(a))
365+
_broadcasted(elt::Type, a::Zeros) = Zeros{elt}(axes(a))
366+
_broadcasted(::typeof(inv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a)
367+
using LinearAlgebra: pinv
368+
_broadcasted(::typeof(pinv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a)
369+
_broadcasted(::typeof(pinv), a::Zeros) = _broadcasted(typeof(inv(zero(eltype(a)))), a)
370+
_broadcasted(::typeof(sqrt), a::Ones) = _broadcasted(typeof(sqrt(one(eltype(a)))), a)
371+
_broadcasted(::typeof(sqrt), a::Zeros) = _broadcasted(typeof(sqrt(zero(eltype(a)))), a)
372+
_broadcasted(::typeof(cbrt), a::Ones) = _broadcasted(typeof(cbrt(one(eltype(a)))), a)
373+
_broadcasted(::typeof(cbrt), a::Zeros) = _broadcasted(typeof(cbrt(zero(eltype(a)))), a)
374+
_broadcasted(::typeof(exp), a::Zeros) = Ones{typeof(exp(zero(eltype(a))))}(axes(a))
375+
_broadcasted(::typeof(cis), a::Zeros) = Ones{typeof(cis(zero(eltype(a))))}(axes(a))
376+
_broadcasted(::typeof(log), a::Ones) = Zeros{typeof(log(one(eltype(a))))}(axes(a))
377+
_broadcasted(::typeof(cos), a::Zeros) = Ones{typeof(cos(zero(eltype(a))))}(axes(a))
378+
_broadcasted(::typeof(sin), a::Zeros) = _broadcasted(typeof(sin(zero(eltype(a)))), a)
379+
_broadcasted(::typeof(tan), a::Zeros) = _broadcasted(typeof(tan(zero(eltype(a)))), a)
380+
_broadcasted(::typeof(sec), a::Zeros) = Ones{typeof(sec(zero(eltype(a))))}(axes(a))
381+
_broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axes(a))
382+
# Eager version of `_broadcasted`.
383+
_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a))
384+
385+
function Broadcast.broadcasted(
386+
::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T,N,Diag}
387+
) where {F,T,N,Diag<:AbstractFill{T}}
388+
# TODO: Check that `f` preserves zeros?
389+
return DiagonalArray(_broadcasted(f, diagview(a)), axes(a))
390+
end

0 commit comments

Comments
 (0)