Skip to content

Commit 075d731

Browse files
committed
fix broadcast regression
1 parent 75577e6 commit 075d731

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

src/operations/broadcasts.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,14 @@ modifying(::typeof(emul)) = emul!
103103
# for many operations. But for non-builtins we'd need an API of sorts.
104104
# To get around this for now we will require that Vectors be on the left
105105
# and transposed vectors be on the right.
106-
if left isa AbstractGBVector && right isa GBMatrixOrTranspose
106+
if left isa AbstractVector && right isa GBMatrixOrTranspose &&
107+
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
107108
return *(Diagonal(left), right, (any, f))
108109
end
109-
if left isa GBMatrixOrTranspose && right isa Transpose{<:Any, <:AbstractGBVector}
110+
if left isa GBMatrixOrTranspose && right isa Transpose{<:Any, <:AbstractVector} &&
111+
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
110112
return *(left, Diagonal(right), (any, f))
111113
end
112-
if left isa GBMatrixOrTranspose && right isa AbstractGBVector
113-
throw(ArgumentError(
114-
"Broadcasting a GBVector into a GBMatrix is only currently " *
115-
"supported with the GBVector on the left."))
116-
end
117-
if right isa GBMatrixOrTranspose && left isa Transpose{<:Any, <:AbstractGBVector}
118-
throw(ArgumentError(
119-
"Broadcasting a Transpose{<:Any, <:AbstractGBVector} into a GBMatrix" *
120-
" is only currently supported with the GBVector on the right."))
121-
end
122114
if left isa GBArrayOrTranspose && right isa GBArrayOrTranspose
123115
add = defaultadd(f)
124116
return add(left, right, f)
@@ -184,11 +176,19 @@ mutatingop(::typeof(apply)) = apply!
184176
# If they're further nested broadcasts we can't fuse them, so just copy.
185177
subargleft isa Broadcast.Broadcasted && (subargleft = copy(subargleft))
186178
subargright isa Broadcast.Broadcasted && (subargright = copy(subargright))
179+
if left isa AbstractVector && right isa GBMatrixOrTranspose &&
180+
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
181+
return *(Diagonal(left), right, (any, f); accum)
182+
end
183+
if left isa GBMatrixOrTranspose && right isa Transpose{<:Any, <:AbstractVector} &&
184+
!(size(left, 1) == size(right, 1) && size(left, 2) == size(right, 2))
185+
return *(left, Diagonal(right), (any, f); accum)
186+
end
187187
if subargleft isa StridedArray
188-
subargleft = pack(subargleft; fill = subargright isa GBArrayOrTranspose ? getfill(right) : nothing)
188+
subargleft = pack(subargleft; fill = subargright isa GBArrayOrTranspose ? getfill(right) : 0)
189189
end
190190
if subargright isa StridedArray
191-
subargright = pack(subargright; fill = subargleft isa GBArrayOrTranspose ? getfill(subargleft) : nothing)
191+
subargright = pack(subargright; fill = subargleft isa GBArrayOrTranspose ? getfill(subargleft) : 0)
192192
end
193193
if subargleft isa GBArrayOrTranspose && subargright isa GBArrayOrTranspose
194194
add = mutatingop(defaultadd(f))
@@ -208,10 +208,10 @@ mutatingop(::typeof(apply)) = apply!
208208
right = copy(right)
209209
end
210210
if left isa StridedArray
211-
left = pack(left; fill = right isa GBArrayOrTranspose ? getfill(right) : nothing)
211+
left = pack(left; fill = right isa GBArrayOrTranspose ? getfill(right) : 0)
212212
end
213213
if right isa StridedArray
214-
right = pack(right; fill = left isa GBArrayOrTranspose ? getfill(left) : nothing)
214+
right = pack(right; fill = left isa GBArrayOrTranspose ? getfill(left) : 0)
215215
end
216216
if left isa GBArrayOrTranspose && right isa GBArrayOrTranspose
217217
add = mutatingop(defaultadd(f))

0 commit comments

Comments
 (0)