Skip to content

Commit 076ac7b

Browse files
committed
Allow += to update vectors in loops.
1 parent 382b190 commit 076ac7b

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/LoopVectorization.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
2323
:exp2 => (:SLEEFPirates, :exp2),
2424
:exp10 => (:SLEEFPirates, :exp10),
2525
:expm1 => (:SLEEFPirates, :expm1),
26-
:sqrt => (:SLEEFPirates, :sqrt), # faster than sqrt_fast
26+
# :sqrt => (:SLEEFPirates, :sqrt), # faster than sqrt_fast
27+
:sqrt => (:SIMDPirates, :sqrt), # faster than sqrt_fast
2728
:rsqrt => (:SIMDPirates, :rsqrt),
2829
:cbrt => (:SLEEFPirates, :cbrt_fast),
2930
:asin => (:SLEEFPirates, :asin_fast),
@@ -216,7 +217,7 @@ end
216217
remr = gensym(:remreps)
217218
q = quote
218219
$Nsym = $N
219-
($Qsym, $remsym) = $(num_vector_load_expr(:($mod.LoopVectorization), N, W<<log2unroll))
220+
($Qsym, $remsym) = $(num_vector_load_expr(:($mod.LoopVectorization), Nsym, W<<log2unroll))
220221
end
221222
if unroll_factor > 1
222223
push!(q.args, :($remr = $remsym >>> $Wshift))
@@ -373,7 +374,17 @@ end
373374
_spirate(prewalk(expr) do x
374375
# @show x
375376
# @show main_body
377+
if @capture(x, A_[i__] += B_)
378+
x = :($A[$(i...)] = $B + $A[$(i...)])
379+
elseif @capture(x, A_[i__] -= B_)
380+
x = :($A[$(i...)] = $A[$(i...)] - $B)
381+
elseif @capture(x, A_[i__] *= B_)
382+
x = :($A[$(i...)] = $B * $A[$(i...)])
383+
elseif @capture(x, A_[i__] /= B_)
384+
x = :($A[$(i...)] = $A[$(i...)] / $B)
385+
end
376386
if @capture(x, A_[i_] = B_) || @capture(x, setindex!(A_, B_, i_))
387+
# println("Made it.")
377388
if A ∉ keys(indexed_expressions)
378389
# pA = esc(gensym(A))
379390
# pA = esc(Symbol(:p,A))
@@ -439,7 +450,6 @@ end
439450
else
440451
pA = indexed_expressions[A]
441452
end
442-
443453
## check to see if we are to do a vector load or a broadcast
444454
if i == declared_iter_sym
445455
load_expr = :($mod.vload($V, $pA + $itersym ))

0 commit comments

Comments
 (0)