@@ -81,144 +81,6 @@ function DerivableInterfaces.zero!(a::EyeEye)
81
81
return throw (ArgumentError (" Can't zero out `Eye ⊗ Eye`." ))
82
82
end
83
83
84
- function Base.:* (a:: Number , b:: EyeKronecker )
85
- return b. a ⊗ (a * b. b)
86
- end
87
- function Base.:* (a:: Number , b:: KroneckerEye )
88
- return (a * b. a) ⊗ b. b
89
- end
90
- function Base.:* (a:: Number , b:: EyeEye )
91
- return error (" Can't multiply `Eye ⊗ Eye` by a number." )
92
- end
93
- function Base.:* (a:: EyeKronecker , b:: Number )
94
- return a. a ⊗ (a. b * b)
95
- end
96
- function Base.:* (a:: KroneckerEye , b:: Number )
97
- return (a. a * b) ⊗ a. b
98
- end
99
- function Base.:* (a:: EyeEye , b:: Number )
100
- return error (" Can't multiply `Eye ⊗ Eye` by a number." )
101
- end
102
-
103
- function Base.:- (a:: EyeKronecker )
104
- return a. a ⊗ (- a. b)
105
- end
106
- function Base.:- (a:: KroneckerEye )
107
- return (- a. a) ⊗ a. b
108
- end
109
- function Base.:- (a:: EyeEye )
110
- return error (" Can't multiply `Eye ⊗ Eye` by a number." )
111
- end
112
-
113
- for op in (:+ , :- )
114
- @eval begin
115
- function Base. $op (a:: EyeKronecker , b:: EyeKronecker )
116
- if a. a ≠ b. a
117
- return throw (
118
- ArgumentError (
119
- " KroneckerArray addition is only supported when the first or secord arguments match." ,
120
- ),
121
- )
122
- end
123
- return a. a ⊗ $ op (a. b, b. b)
124
- end
125
- function Base. $op (a:: KroneckerEye , b:: KroneckerEye )
126
- if a. b ≠ b. b
127
- return throw (
128
- ArgumentError (
129
- " KroneckerArray addition is only supported when the first or secord arguments match." ,
130
- ),
131
- )
132
- end
133
- return $ op (a. a, b. a) ⊗ a. b
134
- end
135
- function Base. $op (a:: EyeEye , b:: EyeEye )
136
- if a. b ≠ b. b
137
- return throw (
138
- ArgumentError (
139
- " KroneckerArray addition is only supported when the first or secord arguments match." ,
140
- ),
141
- )
142
- end
143
- return $ op (a. a, b. a) ⊗ a. b
144
- end
145
- end
146
- end
147
-
148
- function Base. map! (f:: typeof (identity), dest:: EyeKronecker , src:: EyeKronecker )
149
- map! (f, dest. b, src. b)
150
- return dest
151
- end
152
- function Base. map! (f:: typeof (identity), dest:: KroneckerEye , src:: KroneckerEye )
153
- map! (f, dest. a, src. a)
154
- return dest
155
- end
156
- function Base. map! (:: typeof (identity), dest:: EyeEye , src:: EyeEye )
157
- return error (" Can't write in-place." )
158
- end
159
- for f in [:+ , :- ]
160
- @eval begin
161
- function Base. map! (:: typeof ($ f), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
162
- if dest. a ≠ a. a ≠ b. a
163
- throw (
164
- ArgumentError (
165
- " KroneckerArray addition is only supported when the first or second arguments match." ,
166
- ),
167
- )
168
- end
169
- map! ($ f, dest. b, a. b, b. b)
170
- return dest
171
- end
172
- function Base. map! (:: typeof ($ f), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
173
- if dest. b ≠ a. b ≠ b. b
174
- throw (
175
- ArgumentError (
176
- " KroneckerArray addition is only supported when the first or second arguments match." ,
177
- ),
178
- )
179
- end
180
- map! ($ f, dest. a, a. a, b. a)
181
- return dest
182
- end
183
- function Base. map! (:: typeof ($ f), dest:: EyeEye , a:: EyeEye , b:: EyeEye )
184
- return error (" Can't write in-place." )
185
- end
186
- end
187
- end
188
- function Base. map! (f:: typeof (- ), dest:: EyeKronecker , a:: EyeKronecker )
189
- map! (f, dest. b, a. b)
190
- return dest
191
- end
192
- function Base. map! (f:: typeof (- ), dest:: KroneckerEye , a:: KroneckerEye )
193
- map! (f, dest. a, a. a)
194
- return dest
195
- end
196
- function Base. map! (f:: typeof (- ), dest:: EyeEye , a:: EyeEye )
197
- return error (" Can't write in-place." )
198
- end
199
- function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: EyeKronecker , a:: EyeKronecker )
200
- map! (f, dest. b, a. b)
201
- return dest
202
- end
203
- function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerEye , a:: KroneckerEye )
204
- map! (f, dest. a, a. a)
205
- return dest
206
- end
207
- function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: EyeEye , a:: EyeEye )
208
- return error (" Can't write in-place." )
209
- end
210
- function Base. map! (f:: Base.Fix2{typeof(*),<:Number} , dest:: EyeKronecker , a:: EyeKronecker )
211
- map! (f, dest. b, a. b)
212
- return dest
213
- end
214
- function Base. map! (f:: Base.Fix2{typeof(*),<:Number} , dest:: KroneckerEye , a:: KroneckerEye )
215
- map! (f, dest. a, a. a)
216
- return dest
217
- end
218
- function Base. map! (f:: Base.Fix2{typeof(*),<:Number} , dest:: EyeEye , a:: EyeEye )
219
- return error (" Can't write in-place." )
220
- end
221
-
222
84
using Base. Broadcast:
223
85
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
224
86
@@ -233,3 +95,82 @@ Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2
233
95
function Base. similar (bc:: Broadcasted{EyeStyle} , elt:: Type )
234
96
return Eye {elt} (axes (bc))
235
97
end
98
+
99
+ function Base. copyto! (dest:: EyeKronecker , a:: Sum{<:KroneckerStyle{<:Any,EyeStyle()}} )
100
+ dest2 = arg2 (dest)
101
+ f = LinearCombination (a)
102
+ args = arguments (a)
103
+ arg2s = arg2 .(args)
104
+ dest2 .= f .(arg2s... )
105
+ return dest
106
+ end
107
+ function Base. copyto! (dest:: KroneckerEye , a:: Sum{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}} )
108
+ dest1 = arg1 (dest)
109
+ f = LinearCombination (a)
110
+ args = arguments (a)
111
+ arg1s = arg1 .(args)
112
+ dest1 .= f .(arg1s... )
113
+ return dest
114
+ end
115
+ function Base. copyto! (dest:: EyeEye , a:: Sum{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}} )
116
+ return error (" Can't write in-place to `Eye ⊗ Eye`." )
117
+ end
118
+
119
+ # Simplification rules similar to those for FillArrays.jl:
120
+ # https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
121
+ using FillArrays: Zeros
122
+ function Base. broadcasted (
123
+ style:: KroneckerStyle ,
124
+ :: typeof (+ ),
125
+ a:: KroneckerArray ,
126
+ b:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
127
+ )
128
+ # TODO : Promote the element types.
129
+ return a
130
+ end
131
+ function Base. broadcasted (
132
+ style:: KroneckerStyle ,
133
+ :: typeof (+ ),
134
+ a:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
135
+ b:: KroneckerArray ,
136
+ )
137
+ # TODO : Promote the element types.
138
+ return b
139
+ end
140
+ function Base. broadcasted (
141
+ style:: KroneckerStyle ,
142
+ :: typeof (+ ),
143
+ a:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
144
+ b:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
145
+ )
146
+ # TODO : Promote the element types and axes.
147
+ return b
148
+ end
149
+ function Base. broadcasted (
150
+ style:: KroneckerStyle ,
151
+ :: typeof (- ),
152
+ a:: KroneckerArray ,
153
+ b:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
154
+ )
155
+ # TODO : Promote the element types.
156
+ return a
157
+ end
158
+ function Base. broadcasted (
159
+ style:: KroneckerStyle ,
160
+ :: typeof (- ),
161
+ a:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
162
+ b:: KroneckerArray ,
163
+ )
164
+ # TODO : Promote the element types.
165
+ # TODO : Return `broadcasted(-, b)`.
166
+ return - b
167
+ end
168
+ function Base. broadcasted (
169
+ style:: KroneckerStyle ,
170
+ :: typeof (- ),
171
+ a:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
172
+ b:: KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros} ,
173
+ )
174
+ # TODO : Promote the element types and axes.
175
+ return b
176
+ end
0 commit comments