Skip to content

Commit f72a06b

Browse files
committed
correct some mask issues, improve bcast support
1 parent 9aed241 commit f72a06b

File tree

5 files changed

+285
-39
lines changed

5 files changed

+285
-39
lines changed

src/abstractgbarray.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ Base.promote_rule(::Type{<:AbstractGBVector{T, F}}, ::Type{<:AbstractGBVector{T2
5454
GBVector{promote_type(T, T2), promote_type(F, F2)}
5555

5656
function gbpromote_strip(A, B)
57-
if A isa Transpose{<:Any, <:AbstractGBVector} && B isa Transpose{<:Any, <:AbstractGBVector}
58-
return GBMatrix
59-
else
60-
return promote_type(strip_parameters(typeof(parent(A))), strip_parameters(typeof(parent(B))))
61-
end
57+
return promote_type(strip_parameters(typeof(parent(A))), strip_parameters(typeof(parent(B))))
6258
end
59+
gbpromote_strip(::Transpose{<:Any, <:AbstractGBVector}, ::Transpose{<:Any, <:AbstractGBVector}) = GBMatrix
60+
gbpromote_strip(::AbstractGBVector, ::Transpose{<:Any, <:AbstractGBVector}) = GBMatrix
61+
gbpromote_strip(::Transpose{<:Any, <:AbstractGBVector}, ::AbstractGBVector) = GBMatrix
62+
63+
6364

6465
Base.IndexStyle(::AbstractGBArray) = IndexCartesian()
6566
Base.eltype(::AbstractGBArray{T, F}) where {T, F} = Union{T, F}
@@ -534,6 +535,7 @@ function subassign!(
534535
J, nj = idx(J)
535536
desc = _handledescriptor(desc; out=C, in1=A)
536537
mask = _handlemask!(desc, mask)
538+
mask isa AbstractVector && length(I) == 1 && (mask = copy(mask'))
537539
I = decrement!(I)
538540
I !== J && (J = decrement!(J))
539541
rereshape = false
@@ -568,6 +570,7 @@ function subassign!(C::AbstractGBArray{T}, x, I, J;
568570
I !== J && (J = decrement!(J))
569571
desc = _handledescriptor(desc; out=C)
570572
mask = _handlemask!(desc, mask)
573+
mask isa AbstractVector && length(I) == 1 && (mask = copy(mask'))
571574
_subassign(C, x, I, ni, J, nj, mask, _handleaccum(accum, storedeltype(C)), desc)
572575
increment!(I)
573576
I !== J && (decrement!(J))
@@ -637,6 +640,7 @@ function assign!(
637640
J, nj = idx(J)
638641
desc = _handledescriptor(desc; in1=A, out=C)
639642
mask = _handlemask!(desc, mask)
643+
mask isa AbstractVector && length(I) == 1 && (mask = copy(mask'))
640644
I = decrement!(I)
641645
I !== J && (J = decrement!(J))
642646
if !(eltype(A) <: valid_union) || !(eltype(C) <: valid_union)
@@ -659,6 +663,7 @@ function assign!(C::AbstractGBArray{T}, x, I, J;
659663
I !== J && (J = decrement!(J))
660664
desc = _handledescriptor(desc; out=C)
661665
mask = _handlemask!(desc, mask)
666+
mask isa AbstractVector && length(I) == 1 && (mask = copy(mask'))
662667
_assign(C, x, I, ni, J, nj, mask, _handleaccum(accum, storedeltype(C)), desc)
663668
increment!(I)
664669
I !== J && (decrement!(J))

src/indexutils.jl

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,46 @@ function idx(I)
3434
end
3535
end
3636

37-
# This function assumes that szA and szB are
38-
# technically equal and that
39-
# 1 <= length(szA | szB) <= 2
40-
# size checks should be done elsewhere.
41-
function _combinesizes(A, B)
42-
if A isa Transpose{<:Any, <:AbstractVector} && B isa Transpose{<:Any, <:AbstractVector}
43-
return size(A)
44-
end
45-
if (A isa AbstractVector && B isa AbstractMatrix) ||
46-
(B isa AbstractVector && A isa AbstractMatrix)
47-
return (size(A, 1), size(A, 2))
48-
else
49-
return size(A)
50-
end
37+
# Combine sizes for bcasting purposes
38+
# This is quite inelegant :(, does this already exist somewhere?
39+
function _combinesizes(A::AbstractGBVector, B::AbstractGBVector)
40+
size(A) == size(B) && (return size(A)) # same size
41+
size(A, 1) == 1 && (return size(B)) # bcast A into B
42+
size(B, 1) == 1 && (return size(A)) # bcast B into A
43+
throw(DimensionMismatch("Got mismatched dimensions $(size(A)), $(size(B))"))
44+
end
45+
function _combinesizes(A::Transpose{<:Any, <:AbstractGBVector}, B::Transpose{<:Any, <:AbstractGBVector})
46+
size(A) == size(B) && (return size(A)) # same size
47+
size(A, 2) == 1 && (return size(B)) # bcast A into B
48+
size(B, 2) == 1 && (return size(A)) # bcast B into A
49+
throw(DimensionMismatch("Got mismatched dimensions $(size(A)), $(size(B))"))
5150
end
51+
# Outer products (dot is done by mul[!])
52+
function _combinesizes(A::AbstractGBVector, B::Transpose{<:Any, <:AbstractGBVector})
53+
return (size(A, 1), size(B, 2))
54+
end
55+
function _combinesizes(A::Transpose{<:Any, <:AbstractGBVector}, B::AbstractGBVector)
56+
return (size(B, 1), size(A, 2))
57+
end
58+
59+
function _combinesizes(A::GBMatrixOrTranspose, v::AbstractGBVector)
60+
length(v)
61+
size(A, 1) == size(v, 1) && (return size(A))
62+
throw(DimensionMismatch("Got mismatched dimensions $(size(A, 1)) and $(size(v, 1))"))
63+
end
64+
function _combinesizes(v::AbstractGBVector, A::GBMatrixOrTranspose)
65+
size(A, 1) == size(v, 1) && (return size(A))
66+
throw(DimensionMismatch("Got mismatched dimensions $(size(v, 1)) and $(size(A, 1))"))
67+
end
68+
69+
function _combinesizes(A::GBMatrixOrTranspose, v::Transpose{<:Any, <:AbstractGBVector})
70+
size(A, 2) == size(v, 2) && (return size(A))
71+
throw(DimensionMismatch("Got mismatched dimensions $(size(A, 2)) and $(size(v, 2))"))
72+
end
73+
function _combinesizes(v::Transpose{<:Any, <:AbstractGBVector}, A::GBMatrixOrTranspose)
74+
size(A, 2) == size(v, 2) && (return size(A))
75+
throw(DimensionMismatch("Got mismatched dimensions $(size(v, 2)) and $(size(A, 2))"))
76+
end
77+
78+
_combinesizes(A::GBMatrixOrTranspose, B::GBMatrixOrTranspose) = size(A) == size(B) ? size(A) :
79+
throw(DimensionMismatch("Got mismatched dimensions: $(size(A)) and $(size(B))"))

src/operations/broadcasts.jl

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Base.BroadcastStyle(::Type{<:AbstractGBVector}) = GBVectorStyle()
2626
Base.BroadcastStyle(::Type{<:AbstractGBMatrix}) = GBMatrixStyle()
2727
Base.BroadcastStyle(::Type{<:Transpose{T, <:AbstractGBMatrix} where T}) = GBMatrixStyle()
2828
Base.BroadcastStyle(::Type{<:Adjoint{T, <:AbstractGBMatrix} where T}) = GBMatrixStyle()
29+
Base.BroadcastStyle(::Type{<:Transpose{T, <:AbstractGBVector} where T}) = GBVectorStyle()
30+
Base.BroadcastStyle(::Type{<:Adjoint{T, <:AbstractGBVector} where T}) = GBVectorStyle()
2931

3032
#
3133
GBVectorStyle(::Val{0}) = GBVectorStyle()
@@ -91,14 +93,6 @@ modifying(::typeof(emul)) = emul!
9193
if right isa Broadcast.Broadcasted
9294
right = copy(right)
9395
end
94-
if left isa AbstractVector && right isa GBMatrixOrTranspose &&
95-
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
96-
return *(Diagonal(left), right, (any, f))
97-
end
98-
if left isa GBMatrixOrTranspose && right isa Transpose{<:Any, <:AbstractVector} &&
99-
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
100-
return *(left, Diagonal(right), (any, f))
101-
end
10296
if left isa StridedArray
10397
left = pack(left; fill = right isa GBArrayOrTranspose ? getfill(right) : nothing)
10498
end
@@ -177,14 +171,6 @@ mutatingop(::typeof(apply)) = apply!
177171
# If they're further nested broadcasts we can't fuse them, so just copy.
178172
subargleft isa Broadcast.Broadcasted && (subargleft = copy(subargleft))
179173
subargright isa Broadcast.Broadcasted && (subargright = copy(subargright))
180-
if left isa AbstractVector && right isa GBMatrixOrTranspose &&
181-
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
182-
return *(Diagonal(left), right, (any, f); accum)
183-
end
184-
if left isa GBMatrixOrTranspose && right isa Transpose{<:Any, <:AbstractVector} &&
185-
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
186-
return *(left, Diagonal(right), (any, f); accum)
187-
end
188174
if subargleft isa StridedArray
189175
subargleft = pack(subargleft; fill = subargright isa GBArrayOrTranspose ? getfill(right) : 0)
190176
end
@@ -365,4 +351,69 @@ function Base.materialize!(
365351
return setindex!(A, bc.args[begin], :)
366352
end
367353

368-
Base.Broadcast.broadcasted(::Type{T}, A::AbstractGBArray) where T = LinearAlgebra.copy_oftype(A, T)
354+
Base.Broadcast.broadcasted(::Type{T}, A::AbstractGBArray) where T = LinearAlgebra.copy_oftype(A, T)
355+
356+
# This is overly verbose, perhaps a macro?
357+
# return an operator that swaps the order of the operands.
358+
# * -> *, first -> second, second -> first, - -> rminus, etc.
359+
_swapop(op) = throw(ArgumentError("Cannot swap order of operands automatically. Swap the order of the broadcast statement or overload `_swapop`"))
360+
_swapop(::typeof(first)) = second
361+
_swapop(::typeof(second)) = first
362+
363+
_swapop(::typeof(any)) = any
364+
365+
_swapop(::typeof(pair)) = pair
366+
367+
_swapop(::typeof(+)) = +
368+
_swapop(::typeof(-)) = rminus
369+
_swapop(::typeof(rminus)) = -
370+
371+
_swapop(::typeof(*)) = *
372+
_swapop(::typeof(/)) = \
373+
_swapop(::typeof(\)) = /
374+
375+
# ^ / POW doesn't have an equivalent builtin... Error for now.
376+
377+
_swapop(::typeof(iseq)) = iseq
378+
_swapop(::typeof(isne)) = isne
379+
380+
_swapop(::typeof(min)) = min
381+
_swapop(::typeof(max)) = max
382+
383+
_swapop(::typeof(isgt)) = isle
384+
_swapop(::typeof(isle)) = isgt
385+
386+
_swapop(::typeof(isge)) = islt
387+
_swapop(::typeof(islt)) = isge
388+
389+
_swapop(::typeof()) =
390+
_swapop(::typeof()) =
391+
392+
_swapop(::typeof(lxor)) = lxor
393+
_swapop(::typeof(xnor)) = xnor
394+
395+
_swapop(::typeof(==)) = ==
396+
_swapop(::typeof(!=)) = !=
397+
398+
_swapop(::typeof(>)) = <=
399+
_swapop(::typeof(<=)) = >
400+
_swapop(::typeof(<)) = >=
401+
_swapop(::typeof(>=)) = <
402+
403+
# I'm not going to bother with the trig/mod/sign/complex/etc. If you need them please open an issue.
404+
405+
_swapop(::typeof(|)) = |
406+
_swapop(::typeof(&)) = &
407+
_swapop(::typeof()) =
408+
_swapop(::typeof(bxnor)) = bxnor
409+
# bshift has no obvious equivalent in the builtins
410+
411+
_swapop(::typeof(firsti0)) = secondi0
412+
_swapop(::typeof(secondi0)) = firsti0
413+
_swapop(::typeof(firsti)) = secondi
414+
_swapop(::typeof(secondi)) = firsti
415+
416+
_swapop(::typeof(firstj0)) = secondj0
417+
_swapop(::typeof(secondj0)) = firstj0
418+
_swapop(::typeof(firstj)) = secondj
419+
_swapop(::typeof(secondj)) = firstj

0 commit comments

Comments
 (0)