@@ -26,6 +26,8 @@ Base.BroadcastStyle(::Type{<:AbstractGBVector}) = GBVectorStyle()
26
26
Base. BroadcastStyle (:: Type{<:AbstractGBMatrix} ) = GBMatrixStyle ()
27
27
Base. BroadcastStyle (:: Type{<:Transpose{T, <:AbstractGBMatrix} where T} ) = GBMatrixStyle ()
28
28
Base. BroadcastStyle (:: Type{<:Adjoint{T, <:AbstractGBMatrix} where T} ) = GBMatrixStyle ()
29
+ Base. BroadcastStyle (:: Type{<:Transpose{T, <:AbstractGBVector} where T} ) = GBVectorStyle ()
30
+ Base. BroadcastStyle (:: Type{<:Adjoint{T, <:AbstractGBVector} where T} ) = GBVectorStyle ()
29
31
30
32
#
31
33
GBVectorStyle (:: Val{0} ) = GBVectorStyle ()
@@ -91,14 +93,6 @@ modifying(::typeof(emul)) = emul!
91
93
if right isa Broadcast. Broadcasted
92
94
right = copy (right)
93
95
end
94
- if left isa AbstractVector && right isa GBMatrixOrTranspose &&
95
- ! (size (left, 1 ) == size (right, 1 ) && size (left, 2 ) == size (right, 2 ))
96
- return * (Diagonal (left), right, (any, f))
97
- end
98
- if left isa GBMatrixOrTranspose && right isa Transpose{<: Any , <: AbstractVector } &&
99
- ! (size (left, 1 ) == size (right, 1 ) && size (left, 2 ) == size (right, 2 ))
100
- return * (left, Diagonal (right), (any, f))
101
- end
102
96
if left isa StridedArray
103
97
left = pack (left; fill = right isa GBArrayOrTranspose ? getfill (right) : nothing )
104
98
end
@@ -177,14 +171,6 @@ mutatingop(::typeof(apply)) = apply!
177
171
# If they're further nested broadcasts we can't fuse them, so just copy.
178
172
subargleft isa Broadcast. Broadcasted && (subargleft = copy (subargleft))
179
173
subargright isa Broadcast. Broadcasted && (subargright = copy (subargright))
180
- if left isa AbstractVector && right isa GBMatrixOrTranspose &&
181
- ! (size (left, 1 ) == size (right, 1 ) && size (left, 2 ) == size (right, 2 ))
182
- return * (Diagonal (left), right, (any, f); accum)
183
- end
184
- if left isa GBMatrixOrTranspose && right isa Transpose{<: Any , <: AbstractVector } &&
185
- ! (size (left, 1 ) == size (right, 1 ) && size (left, 2 ) == size (right, 2 ))
186
- return * (left, Diagonal (right), (any, f); accum)
187
- end
188
174
if subargleft isa StridedArray
189
175
subargleft = pack (subargleft; fill = subargright isa GBArrayOrTranspose ? getfill (right) : 0 )
190
176
end
@@ -365,4 +351,69 @@ function Base.materialize!(
365
351
return setindex! (A, bc. args[begin ], :)
366
352
end
367
353
368
- Base. Broadcast. broadcasted (:: Type{T} , A:: AbstractGBArray ) where T = LinearAlgebra. copy_oftype (A, T)
354
+ Base. Broadcast. broadcasted (:: Type{T} , A:: AbstractGBArray ) where T = LinearAlgebra. copy_oftype (A, T)
355
+
356
+ # This is overly verbose, perhaps a macro?
357
+ # return an operator that swaps the order of the operands.
358
+ # * -> *, first -> second, second -> first, - -> rminus, etc.
359
+ _swapop (op) = throw (ArgumentError (" Cannot swap order of operands automatically. Swap the order of the broadcast statement or overload `_swapop`" ))
360
+ _swapop (:: typeof (first)) = second
361
+ _swapop (:: typeof (second)) = first
362
+
363
+ _swapop (:: typeof (any)) = any
364
+
365
+ _swapop (:: typeof (pair)) = pair
366
+
367
+ _swapop (:: typeof (+ )) = +
368
+ _swapop (:: typeof (- )) = rminus
369
+ _swapop (:: typeof (rminus)) = -
370
+
371
+ _swapop (:: typeof (* )) = *
372
+ _swapop (:: typeof (/ )) = \
373
+ _swapop (:: typeof (\ )) = /
374
+
375
+ # ^ / POW doesn't have an equivalent builtin... Error for now.
376
+
377
+ _swapop (:: typeof (iseq)) = iseq
378
+ _swapop (:: typeof (isne)) = isne
379
+
380
+ _swapop (:: typeof (min)) = min
381
+ _swapop (:: typeof (max)) = max
382
+
383
+ _swapop (:: typeof (isgt)) = isle
384
+ _swapop (:: typeof (isle)) = isgt
385
+
386
+ _swapop (:: typeof (isge)) = islt
387
+ _swapop (:: typeof (islt)) = isge
388
+
389
+ _swapop (:: typeof (∨ )) = ∨
390
+ _swapop (:: typeof (∧ )) = ∧
391
+
392
+ _swapop (:: typeof (lxor)) = lxor
393
+ _swapop (:: typeof (xnor)) = xnor
394
+
395
+ _swapop (:: typeof (== )) = ==
396
+ _swapop (:: typeof (!= )) = !=
397
+
398
+ _swapop (:: typeof (> )) = <=
399
+ _swapop (:: typeof (<= )) = >
400
+ _swapop (:: typeof (< )) = >=
401
+ _swapop (:: typeof (>= )) = <
402
+
403
+ # I'm not going to bother with the trig/mod/sign/complex/etc. If you need them please open an issue.
404
+
405
+ _swapop (:: typeof (| )) = |
406
+ _swapop (:: typeof (& )) = &
407
+ _swapop (:: typeof (⊻ )) = ⊻
408
+ _swapop (:: typeof (bxnor)) = bxnor
409
+ # bshift has no obvious equivalent in the builtins
410
+
411
+ _swapop (:: typeof (firsti0)) = secondi0
412
+ _swapop (:: typeof (secondi0)) = firsti0
413
+ _swapop (:: typeof (firsti)) = secondi
414
+ _swapop (:: typeof (secondi)) = firsti
415
+
416
+ _swapop (:: typeof (firstj0)) = secondj0
417
+ _swapop (:: typeof (secondj0)) = firstj0
418
+ _swapop (:: typeof (firstj)) = secondj
419
+ _swapop (:: typeof (secondj)) = firstj
0 commit comments