@@ -20,10 +20,58 @@ SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
2020Base. size (a:: DiagonalArray ) = size (unstored (a))
2121Base. axes (a:: DiagonalArray ) = axes (unstored (a))
2222
23+ function DiagonalArray (diag:: AbstractVector , unstored:: Unstored )
24+ return _DiagonalArray (diag, parent (unstored))
25+ end
2326function DiagonalArray (:: UndefInitializer , unstored:: Unstored )
24- return _DiagonalArray (
25- Vector {eltype(unstored)} (undef, minimum (size (unstored))), parent (unstored)
26- )
27+ return DiagonalArray (Vector {eltype(unstored)} (undef, minimum (size (unstored))), unstored)
28+ end
29+
30+ # This helps to support diagonals where the elements are known
31+ # from the types, for example diagonals that are `Zeros` and `Ones`.
32+ function DiagonalArray {T,N,D,U} (
33+ ax:: Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
34+ ) where {T,N,D<: AbstractVector{T} ,U<: AbstractArray{T,N} }
35+ return DiagonalArray (D ((Base. OneTo (minimum (length, ax)),)), Unstored (U (ax)))
36+ end
37+ function DiagonalArray {T,N,D,U} (
38+ ax1:: AbstractUnitRange{<:Integer} , ax_rest:: Vararg{AbstractUnitRange{<:Integer}}
39+ ) where {T,N,D<: AbstractVector{T} ,U<: AbstractArray{T,N} }
40+ return DiagonalArray {T,N,D,U} ((ax1, ax_rest... ))
41+ end
42+ function DiagonalArray {T,N,D,U} (
43+ sz:: Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}}
44+ ) where {T,N,D<: AbstractVector{T} ,U<: AbstractArray{T,N} }
45+ return DiagonalArray {T,N,D,U} (Base. OneTo .(sz))
46+ end
47+ function DiagonalArray {T,N,D,U} (
48+ sz1:: Integer , sz_rest:: Vararg{Integer}
49+ ) where {T,N,D<: AbstractVector{T} ,U<: AbstractArray{T,N} }
50+ return DiagonalArray {T,N,D,U} ((sz1, sz_rest... ))
51+ end
52+
53+ # This helps to support diagonals where the elements are known
54+ # from the types, for example diagonals that are `Zeros` and `Ones`.
55+ # These versions use the default unstored type `Zeros{T,N}`.
56+ function DiagonalArray {T,N,D} (
57+ ax:: Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
58+ ) where {T,N,D<: AbstractVector{T} }
59+ return DiagonalArray {T,N,D,Zeros{T,N}} (ax)
60+ end
61+ function DiagonalArray {T,N,D} (
62+ ax1:: AbstractUnitRange{<:Integer} , ax_rest:: Vararg{AbstractUnitRange{<:Integer}}
63+ ) where {T,N,D<: AbstractVector{T} }
64+ return DiagonalArray {T,N,D,Zeros{T,N}} (ax1, ax_rest... )
65+ end
66+ function DiagonalArray {T,N,D} (
67+ sz:: Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}}
68+ ) where {T,N,D<: AbstractVector{T} }
69+ return DiagonalArray {T,N,D,Zeros{T,N}} (sz)
70+ end
71+ function DiagonalArray {T,N,D} (
72+ sz1:: Integer , sz_rest:: Vararg{Integer}
73+ ) where {T,N,D<: AbstractVector{T} }
74+ return DiagonalArray {T,N,D,Zeros{T,N}} (sz1, sz_rest... )
2775end
2876
2977# Constructors accepting axes.
@@ -32,7 +80,7 @@ function DiagonalArray{T,N}(
3280 ax:: Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ,
3381) where {T,N}
3482 N == length (ax) || throw (ArgumentError (" Wrong number of axes" ))
35- return _DiagonalArray (convert (AbstractVector{T}, diag), Zeros {T} (ax))
83+ return DiagonalArray (convert (AbstractVector{T}, diag), Unstored ( Zeros {T} (ax) ))
3684end
3785function DiagonalArray {T,N} (
3886 diag:: AbstractVector ,
@@ -97,7 +145,7 @@ function DiagonalArray{T}(
97145end
98146
99147function DiagonalArray {T,N} (diag:: AbstractVector , dims:: Dims{N} ) where {T,N}
100- return _DiagonalArray (convert (AbstractVector{T}, diag), Zeros {T} (dims))
148+ return DiagonalArray (convert (AbstractVector{T}, diag), Unstored ( Zeros {T} (dims) ))
101149end
102150
103151function DiagonalArray {T,N} (diag:: AbstractVector , dims:: Vararg{Int,N} ) where {T,N}
@@ -161,6 +209,28 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
161209 return DiagonalArray {T,N} (undef, dims)
162210end
163211
212+ # 0-dim limit.
213+ function DiagonalArray {T,0,D} (
214+ :: UndefInitializer , ax:: Tuple{}
215+ ) where {T,D<: AbstractVector{T} }
216+ return DiagonalArray {T,0,D} (D (undef, 0 ), ax)
217+ end
218+ function DiagonalArray {T,0,D} (:: UndefInitializer ) where {T,D<: AbstractVector{T} }
219+ return DiagonalArray {T,0,D} (undef, ())
220+ end
221+ function DiagonalArray {T,0} (:: UndefInitializer , ax:: Tuple{} ) where {T}
222+ return DiagonalArray {T,0,Vector{T}} (undef, ax)
223+ end
224+ function DiagonalArray {T,0} (:: UndefInitializer ) where {T}
225+ return DiagonalArray {T,0} (undef, ())
226+ end
227+ function DiagonalArray {T} (:: UndefInitializer , axes:: Tuple{} ) where {T}
228+ return DiagonalArray {T,0} (undef, ())
229+ end
230+ function DiagonalArray {T} (:: UndefInitializer ) where {T}
231+ return DiagonalArray {T} (undef, ())
232+ end
233+
164234# Axes version
165235function DiagonalArray {T} (:: UndefInitializer , axes:: NTuple{N,Base.OneTo{Int}} ) where {T,N}
166236 return DiagonalArray {T,N} (undef, length .(axes))
@@ -197,3 +267,109 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
197267 # Unlike `permutedims(::Diagonal, perm)`, we copy here.
198268 return DiagonalArray (diagview (a), ax_perm)
199269end
270+
271+ # Scalar indexing.
272+ using DerivableInterfaces: @interface , interface
273+ one_based_range (r) = false
274+ one_based_range (r:: Base.OneTo ) = true
275+ one_based_range (r:: Base.Slice ) = true
276+ function _diag_axes (a:: DiagonalArray , I... )
277+ return map (ntuple (identity, ndims (a))) do d
278+ return Base. axes1 (axes (a, d)[I[d]])
279+ end
280+ end
281+ # A view that preserves the diagonal structure.
282+ function _view_diag (a:: DiagonalArray , I... )
283+ ax = _diag_axes (a, I... )
284+ return DiagonalArray (view (diagview (a), Base. OneTo (minimum (length, I))), ax)
285+ end
286+ # A slice that preserves the diagonal structure.
287+ function _getindex_diag (a:: DiagonalArray , I... )
288+ ax = _diag_axes (a, I... )
289+ return DiagonalArray (diagview (a)[Base. OneTo (minimum (length, I))], ax)
290+ end
291+ function Base. view (a:: DiagonalArray , I... )
292+ I′ = to_indices (a, I)
293+ return if all (one_based_range, I′)
294+ _view_diag (a, I′... )
295+ else
296+ invoke (view, Tuple{AbstractArray,Vararg}, a, I′... )
297+ end
298+ end
299+ function Base. getindex (a:: DiagonalArray , I:: Int... )
300+ return @interface interface (a) a[I... ]
301+ end
302+ function Base. getindex (a:: DiagonalArray , I:: DiagIndex )
303+ return getdiagindex (a, index (I))
304+ end
305+ function Base. getindex (a:: DiagonalArray , I:: DiagIndices )
306+ # TODO : Should this be a view?
307+ return @view diagview (a)[indices (I)]
308+ end
309+ function Base. getindex (a:: DiagonalArray , I... )
310+ I′ = to_indices (a, I)
311+ return if all (i -> i isa Real, I′)
312+ # Catch scalar indexing case.
313+ @interface interface (a) a[I... ]
314+ elseif all (one_based_range, I′)
315+ _getindex_diag (a, I′... )
316+ else
317+ copy (view (a, I′... ))
318+ end
319+ end
320+
321+ # Define in order to preserve immutable diagonals such as FillArrays.
322+ function DiagonalArray {T,N} (a:: DiagonalArray{T,N} ) where {T,N}
323+ # TODO : Should this copy? This matches the design of `LinearAlgebra.Diagonal`:
324+ # https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112
325+ return a
326+ end
327+ function DiagonalArray {T,N} (a:: DiagonalArray{<:Any,N} ) where {T,N}
328+ return DiagonalArray {T,N} (diagview (a))
329+ end
330+ function DiagonalArray {T} (a:: DiagonalArray ) where {T}
331+ return DiagonalArray {T,ndims(a)} (a)
332+ end
333+ function DiagonalArray (a:: DiagonalArray )
334+ return DiagonalArray {eltype(a),ndims(a)} (a)
335+ end
336+ function Base. AbstractArray {T,N} (a:: DiagonalArray{<:Any,N} ) where {T,N}
337+ return DiagonalArray {T,N} (a)
338+ end
339+
340+ # TODO : These definitions work around this issue:
341+ # https://github.com/JuliaArrays/FillArrays.jl/issues/416
342+ # when the diagonal is a FillArrays.Ones or Zeros.
343+ using Base. Broadcast: Broadcast, broadcast, broadcasted
344+ using FillArrays: AbstractFill, Ones, Zeros
345+ _broadcasted (f:: F , a:: AbstractArray ) where {F} = broadcasted (f, a)
346+ _broadcasted (:: typeof (identity), a:: Ones ) = a
347+ _broadcasted (:: typeof (identity), a:: Zeros ) = a
348+ _broadcasted (:: typeof (complex), a:: Ones ) = Ones {complex(eltype(a))} (axes (a))
349+ _broadcasted (:: typeof (complex), a:: Zeros ) = Zeros {complex(eltype(a))} (axes (a))
350+ _broadcasted (elt:: Type , a:: Ones ) = Ones {elt} (axes (a))
351+ _broadcasted (elt:: Type , a:: Zeros ) = Zeros {elt} (axes (a))
352+ _broadcasted (:: typeof (inv), a:: Ones ) = _broadcasted (typeof (inv (oneunit (eltype (a)))), a)
353+ using LinearAlgebra: pinv
354+ _broadcasted (:: typeof (pinv), a:: Ones ) = _broadcasted (typeof (inv (oneunit (eltype (a)))), a)
355+ _broadcasted (:: typeof (sqrt), a:: Ones ) = _broadcasted (typeof (sqrt (one (eltype (a)))), a)
356+ _broadcasted (:: typeof (sqrt), a:: Zeros ) = _broadcasted (typeof (sqrt (zero (eltype (a)))), a)
357+ _broadcasted (:: typeof (cbrt), a:: Ones ) = _broadcasted (typeof (cbrt (one (eltype (a)))), a)
358+ _broadcasted (:: typeof (cbrt), a:: Zeros ) = _broadcasted (typeof (cbrt (zero (eltype (a)))), a)
359+ _broadcasted (:: typeof (exp), a:: Zeros ) = Ones {typeof(exp(zero(eltype(a))))} (axes (a))
360+ _broadcasted (:: typeof (cis), a:: Zeros ) = Ones {typeof(cis(zero(eltype(a))))} (axes (a))
361+ _broadcasted (:: typeof (log), a:: Ones ) = Zeros {typeof(log(one(eltype(a))))} (axes (a))
362+ _broadcasted (:: typeof (cos), a:: Zeros ) = Ones {typeof(cos(zero(eltype(a))))} (axes (a))
363+ _broadcasted (:: typeof (sin), a:: Zeros ) = _broadcasted (typeof (sin (zero (eltype (a)))), a)
364+ _broadcasted (:: typeof (tan), a:: Zeros ) = _broadcasted (typeof (tan (zero (eltype (a)))), a)
365+ _broadcasted (:: typeof (sec), a:: Zeros ) = Ones {typeof(sec(zero(eltype(a))))} (axes (a))
366+ _broadcasted (:: typeof (cosh), a:: Zeros ) = Ones {typeof(cosh(zero(eltype(a))))} (axes (a))
367+ # Eager version of `_broadcasted`.
368+ _broadcast (f:: F , a:: AbstractArray ) where {F} = copy (_broadcasted (f, a))
369+
370+ function Broadcast. broadcasted (
371+ :: DiagonalArrayStyle{N} , f:: F , a:: DiagonalArray{T,N,Diag}
372+ ) where {F,T,N,Diag<: AbstractFill{T} }
373+ # TODO : Check that `f` preserves zeros?
374+ return DiagonalArray (_broadcasted (f, diagview (a)), axes (a))
375+ end
0 commit comments