Skip to content

Commit 81e63b8

Browse files
authored
Fix mutable allocations in n-ary multiplication (#326)
1 parent 7a87b02 commit 81e63b8

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/rewrite_generic.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,8 @@ function _rewrite_generic(stack::Expr, expr::Expr)
216216
push!(stack.args, :($root = $rhs))
217217
for i in 4:length(expr.args)
218218
arg, _ = _rewrite_generic(stack, expr.args[i])
219-
rhs = if is_mutable
220-
Expr(:call, operate!!, *, root, arg)
221-
else
222-
Expr(:call, operate, *, root, arg)
223-
end
219+
# It is always safe to modify `root` here
220+
rhs = Expr(:call, operate!!, *, root, arg)
224221
root = gensym()
225222
push!(stack.args, :($root = $rhs))
226223
end

test/rewrite_generic.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,18 @@ function test_allocations_rewrite_unary_minus()
498498
return
499499
end
500500

501+
function test_allocations_rewrite_mult()
502+
x = BigFloat[1, 2, 3]
503+
y = MA.@rewrite(x[1] * x[2] * x[3], move_factors_into_sums = false)
504+
@test y == BigFloat(6)
505+
@test x == BigFloat[1, 2, 3]
506+
total = @allocated (x[1] * x[2]) * x[3]
507+
@test @allocated(
508+
MA.@rewrite(x[1] * x[2] * x[3], move_factors_into_sums = false),
509+
) < total
510+
return
511+
end
512+
501513
end # module
502514

503515
TestRewriteGeneric.runtests()

0 commit comments

Comments
 (0)