Skip to content

Commit 9ab66f2

Browse files
authored
Fix broadcast for mis-matched arrays (#159)
1 parent 788364d commit 9ab66f2

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/broadcast.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,31 @@ function broadcast_mutability(x, op, args::Vararg{Any,N}) where {N}
8989
return broadcast_mutability(typeof(x), op, typeof.(args)...)
9090
end
9191

92+
_checked_size(s, x::AbstractArray) = size(x) == s
93+
_checked_size(::Any, ::Any) = true
94+
_checked_size(::Any, ::Tuple{}) = true
95+
function _checked_size(s, x::Tuple)
96+
return _checked_size(s, x[1]) && _checked_size(s, Base.tail(x))
97+
end
98+
99+
# This method is a slightly tricky one:
100+
#
101+
# If the elements in the broadcast are different sized arrays, weird things can
102+
# happen during broadcasting since we'll either need to return a different size
103+
# to `x`, or multiple copies of an argument will be used for different parts of
104+
# `x`. To simplify, let's just return `IsNotMutable` if the sizes are different,
105+
# which will be slower but correct.
106+
function broadcast_mutability(
107+
x::AbstractArray,
108+
op,
109+
args::Vararg{Any,N},
110+
) where {N}
111+
if !_checked_size(size(x), args)::Bool
112+
return IsNotMutable()
113+
end
114+
return broadcast_mutability(typeof(x), op, typeof.(args)...)
115+
end
116+
92117
broadcast_mutability(::Type) = IsNotMutable()
93118

94119
"""

test/broadcast.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ const MA = MutableArithmetics
1919
alloc_test(() -> MA.broadcast!!(+, a, b), 0)
2020
alloc_test(() -> MA.broadcast!!(+, a, c), 0)
2121
end
22+
2223
@testset "BigInt" begin
2324
x = BigInt(1)
2425
y = BigInt(2)
@@ -34,3 +35,13 @@ end
3435
alloc_test(() -> MA.broadcast!!(+, a, b), 30 * sizeof(Int))
3536
alloc_test(() -> MA.broadcast!!(+, a, c), 0)
3637
end
38+
39+
@testset "broadcast_issue_158" begin
40+
x, y = BigInt[2 3], BigInt[2 3; 3 4]
41+
@test MA.@rewrite(x .+ y) == x .+ y
42+
@test MA.@rewrite(x .- y) == x .- y
43+
@test MA.@rewrite(y .+ x) == y .+ x
44+
@test MA.@rewrite(y .- x) == y .- x
45+
@test MA.@rewrite(y .* x) == y .* x
46+
@test MA.@rewrite(x .* y) == x .* y
47+
end

0 commit comments

Comments
 (0)