226
226
for op in (:+ , :- )
227
227
@eval begin
228
228
function Base. $op (a:: KroneckerArray , b:: KroneckerArray )
229
+ iszero (a) && return $ op (b)
230
+ iszero (b) && return a
229
231
if a. b == b. b
230
232
return $ op (a. a, b. a) ⊗ a. b
231
233
elseif a. a == b. a
@@ -241,8 +243,15 @@ for op in (:+, :-)
241
243
end
242
244
end
243
245
244
- using Base. Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
246
+ # Allows for customizations for FillArrays.
247
+ _BroadcastStyle (x) = BroadcastStyle (x)
248
+
249
+ using Base. Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
245
250
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
251
+ arg1 (:: Type{<:KroneckerStyle{<:Any,A}} ) where {A} = A
252
+ arg1 (style:: KroneckerStyle ) = arg1 (typeof (style))
253
+ arg2 (:: Type{<:KroneckerStyle{<:Any,B}} ) where {B} = B
254
+ arg2 (style:: KroneckerStyle ) = arg2 (typeof (style))
246
255
function KroneckerStyle {N} (a:: BroadcastStyle , b:: BroadcastStyle ) where {N}
247
256
return KroneckerStyle {N,a,b} ()
248
257
end
@@ -253,30 +262,69 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
253
262
return KroneckerStyle {M,typeof(A)(v),typeof(B)(v)} ()
254
263
end
255
264
function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B}} ) where {N,A,B}
256
- return KroneckerStyle {N} (BroadcastStyle (A), BroadcastStyle (B))
265
+ return KroneckerStyle {N} (_BroadcastStyle (A), _BroadcastStyle (B))
257
266
end
258
267
function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
259
- return KroneckerStyle {N} (
260
- BroadcastStyle (style1. a, style2. a), BroadcastStyle (style1. b, style2. b)
261
- )
268
+ style_a = BroadcastStyle (arg1 (style1), arg1 (style2))
269
+ (style_a isa Broadcast. Unknown) && return Broadcast. Unknown ()
270
+ style_b = BroadcastStyle (arg2 (style1), arg2 (style2))
271
+ (style_b isa Broadcast. Unknown) && return Broadcast. Unknown ()
272
+ return KroneckerStyle {N} (style_a, style_b)
262
273
end
263
274
function Base. similar (bc:: Broadcasted{<:KroneckerStyle{N,A,B}} , elt:: Type ) where {N,A,B}
264
- ax_a = map (ax -> ax . product . a, axes (bc))
265
- ax_b = map (ax -> ax . product . b, axes (bc))
275
+ ax_a = arg1 .( axes (bc))
276
+ ax_b = arg2 .( axes (bc))
266
277
bc_a = Broadcasted (A, nothing , (), ax_a)
267
278
bc_b = Broadcasted (B, nothing , (), ax_b)
268
279
a = similar (bc_a, elt)
269
280
b = similar (bc_b, elt)
270
281
return a ⊗ b
271
282
end
283
+ # Fallback definition of broadcasting falls back to `map` but assumes
284
+ # inputs have been canonicalized to a map-compatible expression already,
285
+ # for example by absorbing scalar arguments into the function.
272
286
function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:KroneckerStyle} )
273
- return throw (
274
- ArgumentError (
275
- " Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
276
- ),
277
- )
287
+ allequal (axes, bc. args) || throw (ArgumentError (" Broadcasted axes must be equal." ))
288
+ map! (bc. f, dest, bc. args... )
289
+ return dest
278
290
end
279
291
292
+ # Broadcast rewrite rules. Canonicalize inputs to absorb scalar inputs into the
293
+ # function.
294
+ function Base. broadcasted (style:: KroneckerStyle , :: typeof (* ), a:: Number , b:: KroneckerArray )
295
+ return broadcasted (style, Base. Fix1 (* , a), b)
296
+ end
297
+ function Base. broadcasted (style:: KroneckerStyle , :: typeof (* ), a:: KroneckerArray , b:: Number )
298
+ return broadcasted (style, Base. Fix2 (* , b), a)
299
+ end
300
+ function Base. broadcasted (style:: KroneckerStyle , :: typeof (/ ), a:: KroneckerArray , b:: Number )
301
+ return broadcasted (style, Base. Fix2 (/ , b), a)
302
+ end
303
+ using MapBroadcast: MapBroadcast, MapFunction
304
+ function Base. broadcasted (
305
+ style:: KroneckerStyle ,
306
+ f:: MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}} ,
307
+ a:: KroneckerArray ,
308
+ )
309
+ return broadcasted (style, Base. Fix1 (* , f. args[1 ]), a)
310
+ end
311
+ function Base. broadcasted (
312
+ style:: KroneckerStyle ,
313
+ f:: MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}} ,
314
+ a:: KroneckerArray ,
315
+ )
316
+ return broadcasted (style, Base. Fix2 (* , f. args[2 ]), a)
317
+ end
318
+ function Base. broadcasted (
319
+ style:: KroneckerStyle ,
320
+ f:: MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}} ,
321
+ a:: KroneckerArray ,
322
+ )
323
+ return broadcasted (style, Base. Fix2 (/ , f. args[2 ]), a)
324
+ end
325
+
326
+ # TODO : Define by converting to a broadcast expession (with MapBroadcast.jl)
327
+ # and then constructing the output with `similar`.
280
328
function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray... )
281
329
return throw (
282
330
ArgumentError (
@@ -312,6 +360,8 @@ for f in [:+, :-]
312
360
function Base. map! (
313
361
:: typeof ($ f), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray
314
362
)
363
+ iszero (b) && return map! (identity, dest, a)
364
+ iszero (a) && return map! ($ f, dest, b)
315
365
if a. b == b. b
316
366
map! ($ f, dest. a, a. a, b. a)
317
367
map! (identity, dest. b, a. b)
@@ -350,6 +400,15 @@ for op in [:*, :/]
350
400
end
351
401
end
352
402
end
403
+ for f in [:+ , :- ]
404
+ @eval begin
405
+ function Base. map! (:: typeof ($ f), dest:: KroneckerArray , src:: KroneckerArray )
406
+ map! ($ f, dest. a, src. a)
407
+ map! (identity, dest. b, src. b)
408
+ return dest
409
+ end
410
+ end
411
+ end
353
412
354
413
using DiagonalArrays: DiagonalArrays, diagonal
355
414
function DiagonalArrays. diagonal (a:: KroneckerArray )
0 commit comments