11using FillArrays: Zeros
22using SparseArraysBase: Unstored, unstored
33
4- function _DiagonalArray end
4+ diaglength_from_shape (sz:: Tuple{Integer,Vararg{Integer}} ) = minimum (sz)
5+ function diaglength_from_shape (
6+ sz:: Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
7+ )
8+ return minimum (length, sz)
9+ end
10+ diaglength_from_shape (sz:: Tuple{} ) = 1
511
6- struct DiagonalArray{T,N,Diag <: AbstractVector{T} ,Unstored <: AbstractArray{T,N} } < :
12+ struct DiagonalArray{T,N,D <: AbstractVector{T} ,U <: AbstractArray{T,N} } < :
713 AbstractDiagonalArray{T,N}
8- diag:: Diag
9- unstored:: Unstored
10- global @inline function _DiagonalArray (
11- diag:: Diag , unstored:: Unstored
12- ) where {T,N,Diag <: AbstractVector{T} ,Unstored <: AbstractArray{T,N} }
13- length (diag) == minimum (size (unstored)) ||
14+ diag:: D
15+ unstored:: U
16+ function DiagonalArray {T,N,D,U} (
17+ diag:: AbstractVector , unstored:: Unstored
18+ ) where {T,N,D <: AbstractVector{T} ,U <: AbstractArray{T,N} }
19+ length (diag) == diaglength_from_shape (size (unstored)) ||
1420 throw (ArgumentError (" Length of diagonals doesn't match dimensions" ))
15- return new {T,N,Diag,Unstored } (diag, unstored)
21+ return new {T,N,D,U } (diag, parent ( unstored) )
1622 end
1723end
1824
1925SparseArraysBase. unstored (a:: DiagonalArray ) = a. unstored
2026Base. size (a:: DiagonalArray ) = size (unstored (a))
2127Base. axes (a:: DiagonalArray ) = axes (unstored (a))
2228
29+ function DiagonalArray {T,N,D} (
30+ diag:: D , unstored:: Unstored{T,N,U}
31+ ) where {T,N,D<: AbstractVector{T} ,U<: AbstractArray{T,N} }
32+ return DiagonalArray {T,N,D,U} (diag, unstored)
33+ end
34+ function DiagonalArray {T,N} (
35+ diag:: D , unstored:: Unstored{T,N}
36+ ) where {T,N,D<: AbstractVector{T} }
37+ return DiagonalArray {T,N,D} (diag, unstored)
38+ end
39+ function DiagonalArray {T} (diag:: AbstractVector{T} , unstored:: Unstored{T,N} ) where {T,N}
40+ return DiagonalArray {T,N} (diag, unstored)
41+ end
42+ function DiagonalArray (diag:: AbstractVector{T} , unstored:: Unstored{T} ) where {T}
43+ return DiagonalArray {T} (diag, unstored)
44+ end
45+
2346function DiagonalArray (:: UndefInitializer , unstored:: Unstored )
24- return _DiagonalArray (
25- Vector {eltype(unstored)} (undef, minimum (size (unstored))), parent (unstored)
47+ return DiagonalArray (
48+ Vector {eltype(unstored)} (undef, diaglength_from_shape (size (unstored))), unstored
49+ )
50+ end
51+
52+ # Indicate we will construct an array just from the shape,
53+ # for example for a Base.OneTo or FillArrays.Ones or Zeros.
54+ # All the elements should be uniquely defined by the input axes.
55+ struct ShapeInitializer end
56+
57+ # This is used to create custom constructors for arrays,
58+ # in this case a generic constructor of a vector from a length.
59+ function construct (vect:: Type{<:AbstractVector} , :: ShapeInitializer , len:: Integer )
60+ if applicable (vect, len)
61+ return vect (len)
62+ elseif applicable (vect, (Base. OneTo (len),))
63+ return vect ((Base. OneTo (len),))
64+ else
65+ error (lazy " Can't construct $(vect) from length." )
66+ end
67+ end
68+
69+ # This helps to support diagonals where the elements are known
70+ # from the types, for example diagonals that are `Zeros` and `Ones`.
71+ function DiagonalArray {T,N,D} (
72+ init:: ShapeInitializer , unstored:: Unstored
73+ ) where {T,N,D<: AbstractVector{T} }
74+ return DiagonalArray {T,N,D} (
75+ construct (D, init, diaglength_from_shape (axes (unstored))), unstored
2676 )
2777end
2878
29- # Constructors accepting axes.
79+ # This helps to support diagonals where the elements are known
80+ # from the types, for example diagonals that are `Zeros` and `Ones`.
81+ # These versions use the default unstored type `Zeros{T,N}`.
82+ function DiagonalArray {T,N,D} (
83+ init:: ShapeInitializer , ax:: Tuple{Vararg{AbstractUnitRange{<:Integer}}}
84+ ) where {T,N,D<: AbstractVector{T} }
85+ return DiagonalArray {T,N,D} (init, Unstored (Zeros {T,N} (ax)))
86+ end
87+ function DiagonalArray {T,N,D} (
88+ init:: ShapeInitializer , ax:: AbstractUnitRange{<:Integer} ...
89+ ) where {T,N,D<: AbstractVector{T} }
90+ return DiagonalArray {T,N,D} (init, ax)
91+ end
92+ function DiagonalArray {T,N,D} (
93+ init:: ShapeInitializer , sz:: Tuple{Integer,Vararg{Integer}}
94+ ) where {T,N,D<: AbstractVector{T} }
95+ return DiagonalArray {T,N,D} (init, Base. OneTo .(sz))
96+ end
97+ function DiagonalArray {T,N,D} (
98+ init:: ShapeInitializer , sz1:: Integer , sz_rest:: Integer...
99+ ) where {T,N,D<: AbstractVector{T} }
100+ return DiagonalArray {T,N,D} (init, (sz1, sz_rest... ))
101+ end
102+
103+ # Constructor from diagonal entries accepting axes.
30104function DiagonalArray {T,N} (
31105 diag:: AbstractVector ,
32106 ax:: Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ,
33107) where {T,N}
34108 N == length (ax) || throw (ArgumentError (" Wrong number of axes" ))
35- return _DiagonalArray (convert (AbstractVector{T}, diag), Zeros {T} (ax))
109+ return DiagonalArray (convert (AbstractVector{T}, diag), Unstored ( Zeros {T} (ax) ))
36110end
37111function DiagonalArray {T,N} (
38112 diag:: AbstractVector ,
@@ -97,7 +171,7 @@ function DiagonalArray{T}(
97171end
98172
99173function DiagonalArray {T,N} (diag:: AbstractVector , dims:: Dims{N} ) where {T,N}
100- return _DiagonalArray (convert (AbstractVector{T}, diag), Zeros {T} (dims))
174+ return DiagonalArray (convert (AbstractVector{T}, diag), Unstored ( Zeros {T} (dims) ))
101175end
102176
103177function DiagonalArray {T,N} (diag:: AbstractVector , dims:: Vararg{Int,N} ) where {T,N}
146220
147221# undef
148222function DiagonalArray {T,N} (:: UndefInitializer , dims:: Dims{N} ) where {T,N}
149- return DiagonalArray {T,N} (Vector {T} (undef, minimum (dims)), dims)
223+ return DiagonalArray {T,N} (Vector {T} (undef, diaglength_from_shape (dims)), dims)
150224end
151225
152226function DiagonalArray {T,N} (:: UndefInitializer , dims:: Vararg{Int,N} ) where {T,N}
@@ -162,8 +236,10 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
162236end
163237
164238# Axes version
165- function DiagonalArray {T} (:: UndefInitializer , axes:: NTuple{N,Base.OneTo{Int}} ) where {T,N}
166- return DiagonalArray {T,N} (undef, length .(axes))
239+ function DiagonalArray {T} (
240+ :: UndefInitializer , axes:: Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}}
241+ ) where {T}
242+ return DiagonalArray {T,length(axes)} (undef, length .(axes))
167243end
168244
169245function Base. similar (a:: DiagonalArray , unstored:: Unstored )
@@ -197,3 +273,118 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
197273 # Unlike `permutedims(::Diagonal, perm)`, we copy here.
198274 return DiagonalArray (diagview (a), ax_perm)
199275end
276+
277+ # Scalar indexing.
278+ using DerivableInterfaces: @interface , interface
279+ one_based_range (r) = false
280+ one_based_range (r:: Base.OneTo ) = true
281+ one_based_range (r:: Base.Slice ) = true
282+ function _diag_axes (a:: DiagonalArray , I... )
283+ return map (ntuple (identity, ndims (a))) do d
284+ return Base. axes1 (axes (a, d)[I[d]])
285+ end
286+ end
287+ # A view that preserves the diagonal structure.
288+ function _view_diag (a:: DiagonalArray , I... )
289+ ax = _diag_axes (a, I... )
290+ return DiagonalArray (view (diagview (a), Base. OneTo (minimum (length, I))), ax)
291+ end
292+ function _view_diag (a:: DiagonalArray , I1:: Base.Slice , Irest:: Base.Slice... )
293+ ax = _diag_axes (a, I1, Irest... )
294+ return DiagonalArray (view (diagview (a), :), ax)
295+ end
296+ # A slice that preserves the diagonal structure.
297+ function _getindex_diag (a:: DiagonalArray , I... )
298+ ax = _diag_axes (a, I... )
299+ return DiagonalArray (diagview (a)[Base. OneTo (minimum (length, I))], ax)
300+ end
301+ function _getindex_diag (a:: DiagonalArray , I1:: Base.Slice , Irest:: Base.Slice... )
302+ ax = _diag_axes (a, I1, Irest... )
303+ return DiagonalArray (diagview (a)[:], ax)
304+ end
305+ function Base. view (a:: DiagonalArray , I... )
306+ I′ = to_indices (a, I)
307+ return if all (one_based_range, I′)
308+ _view_diag (a, I′... )
309+ else
310+ invoke (view, Tuple{AbstractArray,Vararg}, a, I′... )
311+ end
312+ end
313+ function Base. getindex (a:: DiagonalArray , I:: Int... )
314+ return @interface interface (a) a[I... ]
315+ end
316+ function Base. getindex (a:: DiagonalArray , I:: DiagIndex )
317+ return getdiagindex (a, index (I))
318+ end
319+ function Base. getindex (a:: DiagonalArray , I:: DiagIndices )
320+ # TODO : Should this be a view?
321+ return @view diagview (a)[indices (I)]
322+ end
323+ function Base. getindex (a:: DiagonalArray , I... )
324+ I′ = to_indices (a, I)
325+ return if all (i -> i isa Real, I′)
326+ # Catch scalar indexing case.
327+ @interface interface (a) a[I... ]
328+ elseif all (one_based_range, I′)
329+ _getindex_diag (a, I′... )
330+ else
331+ copy (view (a, I′... ))
332+ end
333+ end
334+
335+ # Define in order to preserve immutable diagonals such as FillArrays.
336+ function DiagonalArray {T,N} (a:: DiagonalArray{T,N} ) where {T,N}
337+ # TODO : Should this copy? This matches the design of `LinearAlgebra.Diagonal`:
338+ # https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112
339+ return a
340+ end
341+ function DiagonalArray {T,N} (a:: DiagonalArray{<:Any,N} ) where {T,N}
342+ return DiagonalArray {T,N} (diagview (a))
343+ end
344+ function DiagonalArray {T} (a:: DiagonalArray ) where {T}
345+ return DiagonalArray {T,ndims(a)} (a)
346+ end
347+ function DiagonalArray (a:: DiagonalArray )
348+ return DiagonalArray {eltype(a),ndims(a)} (a)
349+ end
350+ function Base. AbstractArray {T,N} (a:: DiagonalArray{<:Any,N} ) where {T,N}
351+ return DiagonalArray {T,N} (a)
352+ end
353+
354+ # TODO : These definitions work around this issue:
355+ # https://github.com/JuliaArrays/FillArrays.jl/issues/416
356+ # when the diagonal is a FillArrays.Ones or Zeros.
357+ using Base. Broadcast: Broadcast, broadcast, broadcasted
358+ using FillArrays: AbstractFill, Ones, Zeros
359+ _broadcasted (f:: F , a:: AbstractArray ) where {F} = broadcasted (f, a)
360+ _broadcasted (:: typeof (identity), a:: Ones ) = a
361+ _broadcasted (:: typeof (identity), a:: Zeros ) = a
362+ _broadcasted (:: typeof (complex), a:: Ones ) = Ones {complex(eltype(a))} (axes (a))
363+ _broadcasted (:: typeof (complex), a:: Zeros ) = Zeros {complex(eltype(a))} (axes (a))
364+ _broadcasted (elt:: Type , a:: Ones ) = Ones {elt} (axes (a))
365+ _broadcasted (elt:: Type , a:: Zeros ) = Zeros {elt} (axes (a))
366+ _broadcasted (:: typeof (inv), a:: Ones ) = _broadcasted (typeof (inv (oneunit (eltype (a)))), a)
367+ using LinearAlgebra: pinv
368+ _broadcasted (:: typeof (pinv), a:: Ones ) = _broadcasted (typeof (inv (oneunit (eltype (a)))), a)
369+ _broadcasted (:: typeof (pinv), a:: Zeros ) = _broadcasted (typeof (inv (zero (eltype (a)))), a)
370+ _broadcasted (:: typeof (sqrt), a:: Ones ) = _broadcasted (typeof (sqrt (one (eltype (a)))), a)
371+ _broadcasted (:: typeof (sqrt), a:: Zeros ) = _broadcasted (typeof (sqrt (zero (eltype (a)))), a)
372+ _broadcasted (:: typeof (cbrt), a:: Ones ) = _broadcasted (typeof (cbrt (one (eltype (a)))), a)
373+ _broadcasted (:: typeof (cbrt), a:: Zeros ) = _broadcasted (typeof (cbrt (zero (eltype (a)))), a)
374+ _broadcasted (:: typeof (exp), a:: Zeros ) = Ones {typeof(exp(zero(eltype(a))))} (axes (a))
375+ _broadcasted (:: typeof (cis), a:: Zeros ) = Ones {typeof(cis(zero(eltype(a))))} (axes (a))
376+ _broadcasted (:: typeof (log), a:: Ones ) = Zeros {typeof(log(one(eltype(a))))} (axes (a))
377+ _broadcasted (:: typeof (cos), a:: Zeros ) = Ones {typeof(cos(zero(eltype(a))))} (axes (a))
378+ _broadcasted (:: typeof (sin), a:: Zeros ) = _broadcasted (typeof (sin (zero (eltype (a)))), a)
379+ _broadcasted (:: typeof (tan), a:: Zeros ) = _broadcasted (typeof (tan (zero (eltype (a)))), a)
380+ _broadcasted (:: typeof (sec), a:: Zeros ) = Ones {typeof(sec(zero(eltype(a))))} (axes (a))
381+ _broadcasted (:: typeof (cosh), a:: Zeros ) = Ones {typeof(cosh(zero(eltype(a))))} (axes (a))
382+ # Eager version of `_broadcasted`.
383+ _broadcast (f:: F , a:: AbstractArray ) where {F} = copy (_broadcasted (f, a))
384+
385+ function Broadcast. broadcasted (
386+ :: DiagonalArrayStyle{N} , f:: F , a:: DiagonalArray{T,N,Diag}
387+ ) where {F,T,N,Diag<: AbstractFill{T} }
388+ # TODO : Check that `f` preserves zeros?
389+ return DiagonalArray (_broadcasted (f, diagview (a)), axes (a))
390+ end
0 commit comments