1- using FillArrays: FillArrays, Zeros
1+ using FillArrays: FillArrays, Ones, Zeros
22function FillArrays. fillsimilar (
33 a:: Zeros{T} ,
44 ax:: Tuple {
@@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
2121const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
2222const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
2323
24+ using DiagonalArrays: Delta
25+ const DeltaKronecker{T,N,A<: Delta{T,N} ,B<: AbstractArray{T,N} } = KroneckerArray{T,N,A,B}
26+ const KroneckerDelta{T,N,A<: AbstractArray{T,N} ,B<: Delta{T,N} } = KroneckerArray{T,N,A,B}
27+ const DeltaDelta{T,N,A<: Delta{T,N} ,B<: Delta{T,N} } = KroneckerArray{T,N,A,B}
28+
2429_getindex (a:: Eye , I1:: Colon , I2:: Colon ) = a
2530_getindex (a:: Eye , I1:: Base.Slice , I2:: Base.Slice ) = a
2631_getindex (a:: Eye , I1:: Base.Slice , I2:: Colon ) = a
@@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
3035_view (a:: Eye , I1:: Base.Slice , I2:: Colon ) = a
3136_view (a:: Eye , I1:: Colon , I2:: Base.Slice ) = a
3237
38+ function _getindex (a:: Delta , I1:: Union{Colon,Base.Slice} , Irest:: Union{Colon,Base.Slice} ...)
39+ return a
40+ end
41+ function _view (a:: Delta , I1:: Union{Colon,Base.Slice} , Irest:: Union{Colon,Base.Slice} ...)
42+ return a
43+ end
44+
3345# Like `adapt` but preserves `Eye`.
3446_adapt (to, a:: Eye ) = a
47+ _adapt (to, a:: Delta ) = a
3548
3649# Allows customizing for `FillArrays.Eye`.
3750function _convert (:: Type{AbstractArray{T}} , a:: RectDiagonal ) where {T}
38- _convert (AbstractMatrix{T}, a)
51+ return _convert (AbstractMatrix{T}, a)
3952end
4053function _convert (:: Type{AbstractMatrix{T}} , a:: RectDiagonal ) where {T}
41- RectDiagonal (convert (AbstractVector{T}, _diagview (a)), axes (a))
54+ return RectDiagonal (convert (AbstractVector{T}, _diagview (a)), axes (a))
4255end
4356
4457# Like `similar` but preserves `Eye`.
@@ -74,8 +87,39 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange})
7487 return Eye {eltype(arrayt)} ((only (unique (axs)),))
7588end
7689
77- # Like `copy` but preserves `Eye`.
90+ function _similar (a:: Delta , elt:: Type , axs:: Tuple{Vararg{AbstractUnitRange}} )
91+ return Delta {elt} (axs)
92+ end
93+ function _similar (arrayt:: Type{<:Delta} , axs:: Tuple{Vararg{AbstractUnitRange}} )
94+ return Delta {eltype(arrayt)} (axs)
95+ end
96+
97+ # Like `copy` but preserves `Eye`/`Delta`.
7898_copy (a:: Eye ) = a
99+ _copy (a:: Delta ) = a
100+
101+ function _copyto!! (dest:: Eye{<:Any,N} , src:: Eye{<:Any,N} ) where {N}
102+ size (dest) == size (src) ||
103+ throw (ArgumentError (" Sizes do not match: $(size (dest)) != $(size (src)) ." ))
104+ return dest
105+ end
106+ function _copyto!! (dest:: Delta{<:Any,N} , src:: Delta{<:Any,N} ) where {N}
107+ size (dest) == size (src) ||
108+ throw (ArgumentError (" Sizes do not match: $(size (dest)) != $(size (src)) ." ))
109+ return dest
110+ end
111+
112+ # TODO : Define `DerivableInterfaces.permuteddims` and overload that instead.
113+ function Base. PermutedDimsArray (a:: Delta , perm)
114+ ax_perm = Base. PermutedDimsArrays. genperm (axes (a), perm)
115+ return Delta {eltype(a)} (ax_perm)
116+ end
117+
118+ function _permutedims!! (dest:: Delta , src:: Delta , perm)
119+ Base. PermutedDimsArrays. genperm (axes (src), perm) == axes (dest) ||
120+ throw (ArgumentError (" Permuted axes do not match." ))
121+ return dest
122+ end
79123
80124using DerivableInterfaces: DerivableInterfaces, zero!
81125function DerivableInterfaces. zero! (a:: EyeKronecker )
@@ -90,6 +134,18 @@ function DerivableInterfaces.zero!(a::EyeEye)
90134 return throw (ArgumentError (" Can't zero out `Eye ⊗ Eye`." ))
91135end
92136
137+ function DerivableInterfaces. zero! (a:: DeltaKronecker )
138+ zero! (a. b)
139+ return a
140+ end
141+ function DerivableInterfaces. zero! (a:: KroneckerDelta )
142+ zero! (a. a)
143+ return a
144+ end
145+ function DerivableInterfaces. zero! (a:: DeltaDelta )
146+ return throw (ArgumentError (" Can't zero out `Delta ⊗ Delta`." ))
147+ end
148+
93149using Base. Broadcast:
94150 AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
95151
@@ -101,10 +157,16 @@ end
101157Base. BroadcastStyle (style1:: EyeStyle , style2:: EyeStyle ) = EyeStyle ()
102158Base. BroadcastStyle (style1:: EyeStyle , style2:: DefaultArrayStyle ) = style2
103159
160+ function _copyto!! (dest:: Eye , src:: Broadcasted{<:EyeStyle,<:Any,typeof(identity)} )
161+ axes (dest) == axes (src) || error (" Dimension mismatch." )
162+ return dest
163+ end
164+
104165function Base. similar (bc:: Broadcasted{EyeStyle} , elt:: Type )
105166 return Eye {elt} (axes (bc))
106167end
107168
169+ # TODO : Define in terms of `_copyto!!` that is called on each argument.
108170function Base. copyto! (dest:: EyeKronecker , a:: Sum{<:KroneckerStyle{<:Any,EyeStyle()}} )
109171 dest2 = arg2 (dest)
110172 f = LinearCombination (a)
@@ -125,6 +187,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye
125187 return error (" Can't write in-place to `Eye ⊗ Eye`." )
126188end
127189
190+ struct DeltaStyle{N} <: AbstractArrayStyle{N} end
191+ DeltaStyle (:: Val{N} ) where {N} = DeltaStyle {N} ()
192+ DeltaStyle {M} (:: Val{N} ) where {M,N} = DeltaStyle {N} ()
193+ function _BroadcastStyle (A:: Type{<:Delta} )
194+ return DeltaStyle {ndims(A)} ()
195+ end
196+ Base. BroadcastStyle (style1:: DeltaStyle , style2:: DeltaStyle ) = DeltaStyle ()
197+ Base. BroadcastStyle (style1:: DeltaStyle , style2:: DefaultArrayStyle ) = style2
198+
199+ function _copyto!! (dest:: Delta , src:: Broadcasted{<:DeltaStyle,<:Any,typeof(identity)} )
200+ axes (dest) == axes (src) || error (" Dimension mismatch." )
201+ return dest
202+ end
203+
204+ function Base. similar (bc:: Broadcasted{<:DeltaStyle} , elt:: Type )
205+ return Delta {elt} (axes (bc))
206+ end
207+
208+ # TODO : Dispatch on `DeltaStyle`.
209+ function Base. copyto! (dest:: DeltaKronecker , a:: Sum{<:KroneckerStyle} )
210+ dest2 = arg2 (dest)
211+ f = LinearCombination (a)
212+ args = arguments (a)
213+ arg2s = arg2 .(args)
214+ dest2 .= f .(arg2s... )
215+ return dest
216+ end
217+ # TODO : Dispatch on `DeltaStyle`.
218+ function Base. copyto! (dest:: KroneckerDelta , a:: Sum{<:KroneckerStyle} )
219+ dest1 = arg1 (dest)
220+ f = LinearCombination (a)
221+ args = arguments (a)
222+ arg1s = arg1 .(args)
223+ dest1 .= f .(arg1s... )
224+ return dest
225+ end
226+ # TODO : Dispatch on `DeltaStyle`.
227+ function Base. copyto! (dest:: DeltaDelta , a:: Sum{<:KroneckerStyle} )
228+ return error (" Can't write in-place to `Delta ⊗ Delta`." )
229+ end
230+
128231# Simplification rules similar to those for FillArrays.jl:
129232# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
130233using FillArrays: Zeros
0 commit comments