Skip to content

Commit 4d3fef0

Browse files
committed
Improve behavior when unrolling reductions.
1 parent 7825a8f commit 4d3fef0

File tree

1 file changed

+67
-17
lines changed

1 file changed

+67
-17
lines changed

src/LoopVectorization.jl

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,16 @@ whenever size(A) is known at compile time. Seems to be the case for Julia 1.1.
144144
# ntuple(i -> (i-1) * stride, Val(N))
145145
# end
146146

147+
"""
    replace_syms_i(expr, set, i)

Walk `expr` with `postwalk` and return the result, replacing every `Symbol` that is a
member of `set` with `Symbol(ex, :_, i)` (for example `:acc` becomes `:acc_2` when
`i == 2`). Every other node is passed through unchanged.

# Arguments
- `expr`: the expression to walk.
- `set`: collection of symbols to rename; membership is tested with `∈`.
- `i`: suffix appended after an underscore to each matching symbol.
"""
function replace_syms_i(expr, set, i)
    return postwalk(expr) do ex
        # Only bare symbols listed in `set` are renamed.
        if ex isa Symbol && ex ∈ set
            return Symbol(ex, :_, i)
        else
            return ex
        end
    end
end
156+
147157
@noinline function vectorize_body(N, Tsym::Symbol, uf, n, body, vecdict = SLEEFPiratesDict, VType = SVec, mod = :LoopVectorization)
148158
if Tsym == :Float32
149159
vectorize_body(N, Float32, uf, n, body, vecdict, VType, mod)
@@ -247,9 +257,16 @@ end
247257
$main_body
248258
$itersym += $W
249259
end
250-
# Eventually we should unroll reductions as well. Should also check if the compiler ever does that for us (doubt it)
251-
for u in 1:unroll_factor
260+
if unroll_factor == 1
261+
push!(unrolled_loop_body_expr.args, unrolled_loop_body_iter)
262+
else
263+
ulb = unrolled_loop_body_iter
264+
rep_syms = Set(values(reduction_symbols))
265+
unrolled_loop_body_iter = replace_syms_i(ulb, rep_syms, 0)
252266
push!(unrolled_loop_body_expr.args, unrolled_loop_body_iter)
267+
for u in 1:unroll_factor-1
268+
push!(unrolled_loop_body_expr.args, replace_syms_i(ulb, rep_syms, u))
269+
end
253270
end
254271
unadjitersym = gensym(:unadjitersym)
255272
if loop_max_expr isa Integer && loop_max_expr <= 1
@@ -296,22 +313,55 @@ end
296313
end
297314
end
298315
### now we walk the body to look for reductions
299-
for ((sym,op),gsym) ∈ reduction_symbols
300-
if op == :+ || op == :-
301-
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,zero($T))))
302-
elseif op == :* || op == :/
303-
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,one($T))))
316+
if unroll_factor == 1
317+
for ((sym,op),gsym) ∈ reduction_symbols
318+
if op == :+ || op == :-
319+
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,zero($T))))
320+
elseif op == :* || op == :/
321+
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,one($T))))
322+
end
323+
if op == :+
324+
push!(q.args, :($sym = Base.FastMath.add_fast($sym, $mod.vsum($gsym))))
325+
elseif op == :-
326+
push!(q.args, :($sym = Base.FastMath.sub_fast($sym, $mod.vsum($gsym))))
327+
elseif op == :*
328+
push!(q.args, :($sym = Base.FastMath.mul_fast($sym, $mod.SIMDPirates.vprod($gsym))))
329+
elseif op == :/
330+
push!(q.args, :($sym = Base.FastMath.div_fast($sym, $mod.SIMDPirates.vprod($gsym))))
331+
end
304332
end
305-
if op == :+
306-
# push!(q.args, :(@show $sym, $gsym))
307-
push!(q.args, :($sym = Base.FastMath.add_fast($sym, $mod.vsum($gsym))))
308-
# push!(q.args, :(@show $sym, $gsym))
309-
elseif op == :-
310-
push!(q.args, :($sym = Base.FastMath.sub_fast($sym, $mod.vsum($gsym))))
311-
elseif op == :*
312-
push!(q.args, :($sym = Base.FastMath.mul_fast($sym, $mod.SIMDPirates.vprod($gsym))))
313-
elseif op == :/
314-
push!(q.args, :($sym = Base.FastMath.div_fast($sym, $mod.SIMDPirates.vprod($gsym))))
333+
else
334+
for ((sym,op),gsym_base) ∈ reduction_symbols
335+
for uf ∈ 0:unroll_factor-1
336+
gsym = Symbol(gsym_base, :_, uf)
337+
if op == :+ || op == :-
338+
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,zero($T))))
339+
elseif op == :* || op == :/
340+
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,one($T))))
341+
end
342+
end
343+
func = ((op == :*) | (op == :/)) ? :($mod.vmul) : :($mod.vadd)
344+
uf_new = unroll_factor
345+
while uf_new > 1
346+
uf_new, uf_prev = uf_new >> 1, uf_new
347+
for uf ∈ 0:uf_new - 1 # reduce half divisible by two
348+
push!(q.args, Expr(:(=), Symbol(gsym_base, :_, uf), Expr(:call, func, Symbol(gsym_base, :_, 2uf), Symbol(gsym_base, :_, 2uf + 1))))
349+
end
350+
uf_firstrem = 2uf_new
351+
for uf ∈ uf_firstrem:uf_prev - 1
352+
push!(q.args, Expr(:(=), Symbol(gsym_base, :_, uf - uf_firstrem), Expr(:call, func, Symbol(gsym_base, :_, uf - uf_firstrem), Symbol(gsym_base, :_, uf))))
353+
end
354+
end
355+
gsym = Symbol(gsym_base, :_, 0)
356+
if op == :+
357+
push!(q.args, :($sym = Base.FastMath.add_fast($sym, $mod.vsum($gsym))))
358+
elseif op == :-
359+
push!(q.args, :($sym = Base.FastMath.sub_fast($sym, $mod.vsum($gsym))))
360+
elseif op == :*
361+
push!(q.args, :($sym = Base.FastMath.mul_fast($sym, $mod.SIMDPirates.vprod($gsym))))
362+
elseif op == :/
363+
push!(q.args, :($sym = Base.FastMath.div_fast($sym, $mod.SIMDPirates.vprod($gsym))))
364+
end
315365
end
316366
end
317367
push!(q.args, nothing)

0 commit comments

Comments
 (0)