diff --git a/src/macro.jl b/src/macro.jl index 0f0aa0b..6a3f072 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -315,6 +315,13 @@ function parse_input(expr, store) if store.newarray && Z in store.arrays throw("can't create a new array $Z when this also appears on the right") + elseif Z in store.arrays + accumlike = intersect(store.leftind, store.shiftedind) + if !isempty(accumlike) + @warn "detected accumulation-like behaviour, this might be made an error" accumlike + store.avx = false + append!(store.unsafeleft, accumlike) # this means no threading, but not sure that's enough + end end end @@ -550,12 +557,7 @@ finishleftraw(leftraw, store) = map(enumerate(leftraw)) do (d,i) is_const(i) && store.newarray && (i != 1) && throw("can't fix indices on LHS when making a new array") - if isexpr(i, :$) - i.args[1] isa Symbol || throw("you can only interpolate single symbols, not $ex") - push!(store.scalars, i.args[1]) - return i.args[1] - - elseif isexpr(i, :call) && i.args[1] == :+ && + if isexpr(i, :call) && i.args[1] == :+ && length(i.args)==3 && i.args[3] == :_ # magic un-shift A[i+_, j] := ... i = primeindices(i.args)[2] i isa Symbol || throw("index ($i + _) is too complicated, sorry") @@ -578,6 +580,10 @@ finishleftraw(leftraw, store) = map(enumerate(leftraw)) do (d,i) end return ex # has primes dealt with + + elseif i isa Expr # deal with A[i,$j] = ... and also A[i,$j+1] = ... + ex = MacroTools_postwalk(dollarwalk(store), i) + return ex end i end diff --git a/test/parsing.jl b/test/parsing.jl index 340f62c..db34922 100644 --- a/test/parsing.jl +++ b/test/parsing.jl @@ -698,6 +698,23 @@ end @tullio x[i] := s[i,j] avx=false # Unexpected Pass with LV end + # https://github.com/mcabbott/Tullio.jl/issues/115 + # we must detect that cumsum isn't thread-safe, but should it be illegal? + x = rand(Int8, 10) .+ 0 + y = copy(x) + @tullio y[i] = y[i-1] + y[i] + @test y == cumsum(x) + + z = (y=copy(x),) + @tullio z.y[i] = z.y[i-1] + z.y[i] # version with field access + @test z.y == cumsum(x) + + # https://github.com/mcabbott/Tullio.jl/issues/115 + A = zeros(Int, 3,4) + for r in 2:4 + @tullio A[$r-1, c] = c + $r^2 avx=false # dollar on LHS with a shift + end + @test A == [c + r^2 for r in 2:4, c in 1:4] # get wrong answer with LV end @printline