1- using FillArrays: FillArrays, Zeros
1+ using FillArrays: FillArrays, Ones, Zeros
22function FillArrays. fillsimilar (
33 a:: Zeros{T} ,
44 ax:: Tuple {
@@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
2121const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
2222const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
2323
24+ using DiagonalArrays: Delta
25+ const DeltaKronecker{T,N,A<: Delta{T,N} ,B<: AbstractArray{T,N} } = KroneckerArray{T,N,A,B}
26+ const KroneckerDelta{T,N,A<: AbstractArray{T,N} ,B<: Delta{T,N} } = KroneckerArray{T,N,A,B}
27+ const DeltaDelta{T,N,A<: Delta{T,N} ,B<: Delta{T,N} } = KroneckerArray{T,N,A,B}
28+
2429_getindex (a:: Eye , I1:: Colon , I2:: Colon ) = a
2530_getindex (a:: Eye , I1:: Base.Slice , I2:: Base.Slice ) = a
2631_getindex (a:: Eye , I1:: Base.Slice , I2:: Colon ) = a
@@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
3035_view (a:: Eye , I1:: Base.Slice , I2:: Colon ) = a
3136_view (a:: Eye , I1:: Colon , I2:: Base.Slice ) = a
3237
38+ function _getindex (a:: Delta , I1:: Union{Colon,Base.Slice} , Irest:: Union{Colon,Base.Slice} ...)
39+ return a
40+ end
41+ function _view (a:: Delta , I1:: Union{Colon,Base.Slice} , Irest:: Union{Colon,Base.Slice} ...)
42+ return a
43+ end
44+
3345# Like `adapt` but preserves `Eye`.
3446_adapt (to, a:: Eye ) = a
47+ _adapt (to, a:: Delta ) = a
3548
3649# Allows customizing for `FillArrays.Eye`.
3750function _convert (:: Type{AbstractArray{T}} , a:: RectDiagonal ) where {T}
38- _convert (AbstractMatrix{T}, a)
51+ return _convert (AbstractMatrix{T}, a)
3952end
4053function _convert (:: Type{AbstractMatrix{T}} , a:: RectDiagonal ) where {T}
41- RectDiagonal (convert (AbstractVector{T}, _diagview (a)), axes (a))
54+ return RectDiagonal (convert (AbstractVector{T}, _diagview (a)), axes (a))
4255end
4356
4457# Like `similar` but preserves `Eye`, `Ones`, etc.
@@ -61,8 +74,33 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange})
6174 return Eye {eltype(arrayt)} ((only (unique (axs)),))
6275end
6376
64- # Like `copy` but preserves `Eye`.
77+ function _similar (a:: Delta , elt:: Type , axs:: Tuple{Vararg{AbstractUnitRange}} )
78+ return Delta {elt} (axs)
79+ end
80+ function _similar (arrayt:: Type{<:Delta} , axs:: Tuple{Vararg{AbstractUnitRange}} )
81+ return Delta {eltype(arrayt)} (axs)
82+ end
83+
84+ # Like `copy` but preserves `Eye`/`Delta`.
6585_copy (a:: Eye ) = a
86+ _copy (a:: Delta ) = a
87+
88+ function _copyto!! (dest:: Eye{<:Any,N} , src:: Eye{<:Any,N} ) where {N}
89+ size (dest) == size (src) ||
90+ throw (ArgumentError (" Sizes do not match: $(size (dest)) != $(size (src)) ." ))
91+ return dest
92+ end
93+ function _copyto!! (dest:: Delta{<:Any,N} , src:: Delta{<:Any,N} ) where {N}
94+ size (dest) == size (src) ||
95+ throw (ArgumentError (" Sizes do not match: $(size (dest)) != $(size (src)) ." ))
96+ return dest
97+ end
98+
99+ function _permutedims!! (dest:: Delta , src:: Delta , perm)
100+ Base. PermutedDimsArrays. genperm (axes (src), perm) == axes (dest) ||
101+ throw (ArgumentError (" Permuted axes do not match." ))
102+ return dest
103+ end
66104
67105using Base. Broadcast:
68106 AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
75113Base. BroadcastStyle (style1:: EyeStyle , style2:: EyeStyle ) = EyeStyle ()
76114Base. BroadcastStyle (style1:: EyeStyle , style2:: DefaultArrayStyle ) = style2
77115
116+ function _copyto!! (dest:: Eye , src:: Broadcasted{<:EyeStyle,<:Any,typeof(identity)} )
117+ axes (dest) == axes (src) || error (" Dimension mismatch." )
118+ return dest
119+ end
120+
78121function Base. similar (bc:: Broadcasted{EyeStyle} , elt:: Type )
79122 return Eye {elt} (axes (bc))
80123end
81124
125+ # TODO : Define in terms of `_copyto!!` that is called on each argument.
82126function Base. copyto! (dest:: EyeKronecker , a:: Sum{<:KroneckerStyle{<:Any,EyeStyle()}} )
83127 dest2 = arg2 (dest)
84128 f = LinearCombination (a)
@@ -99,6 +143,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye
99143 return error (" Can't write in-place to `Eye ⊗ Eye`." )
100144end
101145
146+ struct DeltaStyle{N} <: AbstractArrayStyle{N} end
147+ DeltaStyle (:: Val{N} ) where {N} = DeltaStyle {N} ()
148+ DeltaStyle {M} (:: Val{N} ) where {M,N} = DeltaStyle {N} ()
149+ function _BroadcastStyle (A:: Type{<:Delta} )
150+ return DeltaStyle {ndims(A)} ()
151+ end
152+ Base. BroadcastStyle (style1:: DeltaStyle , style2:: DeltaStyle ) = DeltaStyle ()
153+ Base. BroadcastStyle (style1:: DeltaStyle , style2:: DefaultArrayStyle ) = style2
154+
155+ function _copyto!! (dest:: Delta , src:: Broadcasted{<:DeltaStyle,<:Any,typeof(identity)} )
156+ axes (dest) == axes (src) || error (" Dimension mismatch." )
157+ return dest
158+ end
159+
160+ function Base. similar (bc:: Broadcasted{<:DeltaStyle} , elt:: Type )
161+ return Delta {elt} (axes (bc))
162+ end
163+
164+ # TODO : Dispatch on `DeltaStyle`.
165+ function Base. copyto! (dest:: DeltaKronecker , a:: Sum{<:KroneckerStyle} )
166+ dest2 = arg2 (dest)
167+ f = LinearCombination (a)
168+ args = arguments (a)
169+ arg2s = arg2 .(args)
170+ dest2 .= f .(arg2s... )
171+ return dest
172+ end
173+ # TODO : Dispatch on `DeltaStyle`.
174+ function Base. copyto! (dest:: KroneckerDelta , a:: Sum{<:KroneckerStyle} )
175+ dest1 = arg1 (dest)
176+ f = LinearCombination (a)
177+ args = arguments (a)
178+ arg1s = arg1 .(args)
179+ dest1 .= f .(arg1s... )
180+ return dest
181+ end
182+ # TODO : Dispatch on `DeltaStyle`.
183+ function Base. copyto! (dest:: DeltaDelta , a:: Sum{<:KroneckerStyle} )
184+ return error (" Can't write in-place to `Delta ⊗ Delta`." )
185+ end
186+
102187# Simplification rules similar to those for FillArrays.jl:
103188# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
104189using FillArrays: Zeros
0 commit comments