Skip to content

Commit f90a2b4

Browse files
committed
Improve perf of some reductions.
1 parent 5cbab62 commit f90a2b4

File tree

5 files changed

+108
-39
lines changed

5 files changed

+108
-39
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.9.10"
4+
version = "0.9.11"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, data,
99
mask, pick_vector_width_val, MM,
1010
maybestaticlength, maybestaticsize, staticm1, staticp1, staticmul, vzero,
1111
Zero, maybestaticrange, offsetprecalc, lazymul,
12-
maybestaticfirst, maybestaticlast, scalar_less, gep, gesp, pointerforcomparison, NativeTypes,
12+
maybestaticfirst, maybestaticlast, scalar_less, scalar_greaterequal, gep, gesp, pointerforcomparison, NativeTypes,
1313
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, vadd, vsub, vmul,
1414
relu, stridedpointer, StridedPointer, StridedBitPointer, AbstractStridedPointer,
1515
reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, reduced_max, reduced_min, reduce_to_max, reduce_to_min,

src/graphs.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,28 @@ function vec_looprange(loopmax, UF::Int, mangledname::Symbol, ptrcomp::Bool)
105105
end
106106
end
107107
function vec_looprange(loopmax, UF::Int, mangledname, W)
108-
incr = if isone(UF)
109-
Expr(:call, lv(:vsub), W, staticexpr(1))
108+
if isone(UF)
109+
compexpr = subexpr(loopmax, W)
110110
else
111-
Expr(:call, lv(:vsub), Expr(:call, lv(:vmul), W, UF), staticexpr(1))
111+
compexpr = subexpr(loopmax, Expr(:call, lv(:vmul), W, UF))
112112
end
113-
compexpr = subexpr(loopmax, incr)
114-
Expr(:call, :<, mangledname, compexpr)
113+
Expr(:call, :≤, mangledname, compexpr)
115114
end
115+
# function vec_looprange(loopmax, UF::Int, mangledname, W)
116+
# incr = if isone(UF)
117+
# Expr(:call, lv(:vsub), W, staticexpr(1))
118+
# else
119+
# Expr(:call, lv(:vsub), Expr(:call, lv(:vmul), W, UF), staticexpr(1))
120+
# end
121+
# compexpr = subexpr(loopmax, incr)
122+
# Expr(:call, :<, mangledname, compexpr)
123+
# end
116124

117125
function looprange(stopcon, incr::Int, mangledname)
118126
if iszero(incr)
119127
Expr(:call, :≤, mangledname, stopcon)
128+
elseif isone(incr)
129+
Expr(:call, :<, mangledname, stopcon)
120130
else
121131
Expr(:call, :≤, mangledname, subexpr(stopcon, incr))
122132
end

src/lower_store.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function reduce_range!(q::Expr, toreduct::Symbol, instr::Instruction, Uh::Int, U
5656
push!(instrexpr.args, Symbol(toreduct, u + 1))
5757
push!(q.args, Expr(:(=), Symbol(toreduct, (u>>>1)), instrexpr))
5858
end
59-
else
59+
elseif 2Uh < Uh2
6060
for u ∈ Uh:Uh2-2
6161
tru = Symbol(toreduct, u - Uh)
6262
instrexpr = callexpr(instr)
@@ -71,6 +71,13 @@ function reduce_range!(q::Expr, toreduct::Symbol, instr::Instruction, Uh::Int, U
7171
push!(instrexpr.args, Symbol(toreduct, u))
7272
push!(q.args, Expr(:(=), tru, instrexpr))
7373
end
74+
else
75+
for u ∈ 0:Uh2-Uh - 1
76+
instrexpr = callexpr(instr)
77+
push!(instrexpr.args, Symbol(toreduct, u))
78+
push!(instrexpr.args, Symbol(toreduct, u + Uh))
79+
push!(q.args, Expr(:(=), Symbol(toreduct, u), instrexpr))
80+
end
7481
end
7582
end
7683
function reduce_range!(q::Expr, ls::LoopSet, Ulow::Int, Uhigh::Int)

src/lowering.jl

Lines changed: 83 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ function allinteriorunrolled(ls::LoopSet, us::UnrollSpecification, N)
291291
unroll_total ≤ 8
292292
end
293293

294-
function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool)
294+
function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool, initialize::Bool = true, maxiters::Int=-1)
295295
usorig = ls.unrollspecification[]
296296
nisvectorized = isvectorized(us, n)
297297
loopsym = names(ls)[n]
@@ -301,7 +301,7 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
301301
# # return lower_llvm_unroll(ls, us, n, loop)
302302
# end
303303
# sl = startloop(loop, nisvectorized, loopsym)
304-
sl = startloop(ls, us, n)
304+
305305
tc = terminatecondition(ls, us, n, inclmask, 1)
306306
body = lower_block(ls, us, n, inclmask, 1)
307307
# align_loop = isone(n) & (ls.align_loops[] > 0)
@@ -319,11 +319,11 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
319319
foreach(_ -> push!(q.args, body), 1:(length(loop) ÷ W))
320320
elseif nisvectorized
321321
# Expr(:block, loopiteratesatleastonce(loop, true), Expr(:while, expect(tc), body))
322-
q = Expr(:block, Expr(:while, tc, body))
322+
q = Expr(:block, Expr(maxiters == 1 ? :if : :while, tc, body))
323323
else
324324
termcond = gensym(:maybeterm)
325325
push!(body.args, Expr(:(=), termcond, tc))
326-
q = Expr(:block, Expr(:(=), termcond, true), Expr(:while, termcond, body))
326+
q = Expr(:block, Expr(:(=), termcond, true), Expr(maxiters == 1 ? :if : :while, termcond, body))
327327
# Expr(:block, Expr(:while, expect(tc), body))
328328
# Expr(:block, assume(tc), Expr(:while, tc, body))
329329
# push!(body.args, Expr(:&&, expect(Expr(:call, :!, tc)), Expr(:break)))
@@ -346,7 +346,11 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
346346
push!(q.args, Expr(:if, tc, body))
347347
end
348348
end
349-
Expr(:block, Expr(:let, sl, q))
349+
if initialize
350+
Expr(:let, startloop(ls, us, n), q)
351+
else
352+
q
353+
end
350354
end
351355
function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool)
352356
UF = unrollfactor(us, n)
@@ -389,23 +393,49 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
389393
end
390394
remblock = Expr(:block)
391395
(nisvectorized && (UFt > 0) && isone(num_loops(ls))) && push!(remblock.args, definemask(loop))
396+
unroll_cleanup = true
392397
else
393398
remblock = init_remblock(loop, ls.lssm[], n)#loopsym)
399+
# unroll_cleanup = Ureduct > 0 || (nisunrolled ? (u₂ > 1) : (u₁ > 1))
400+
# remblock = unroll_cleanup ? init_remblock(loop, ls.lssm[], n)#loopsym) : Expr(:block)
394401
q = Expr(:while, tc, body)
395402
end
396403
q = if unsigned(Ureduct) < unsigned(UF) # unsigned(-1) == typemax(UInt); is logic relying on twos-complement bad?
397-
UF_cleanup = UF - Ureduct
398-
us_cleanup = nisunrolled ? UnrollSpecification(us, UF_cleanup, u₂) : UnrollSpecification(us, u₁, UF_cleanup)
399-
Expr(
400-
:block,
401-
add_upper_outer_reductions(ls, q, Ureduct, UF, loop, vectorized),
402-
Expr(
403-
# :if, terminatecondition(loop, us, n, loopsym, inclmask, UF_cleanup),
404-
:if, terminatecondition(ls, us, n, inclmask, UF_cleanup),
405-
lower_block(ls, us_cleanup, n, inclmask, UF_cleanup)
406-
),
407-
remblock
408-
)
404+
add_cleanup = true
405+
if isone(Ureduct)
406+
UF_cleanup = 1
407+
if nisvectorized
408+
blockhead = :while
409+
else
410+
blockhead = if UF == 2
411+
if loopisstatic
412+
add_cleanup = UFt == 1
413+
:block
414+
else
415+
:if
416+
end
417+
else
418+
:while
419+
end
420+
UFt = 0
421+
end
422+
elseif 2Ureduct < UF
423+
UF_cleanup = 2
424+
blockhead = :while
425+
else
426+
UF_cleanup = UF - Ureduct
427+
blockhead = :if
428+
end
429+
_q = Expr(:block, add_upper_outer_reductions(ls, q, Ureduct, UF, loop, vectorized, nisvectorized))
430+
if add_cleanup
431+
cleanup_expr = Expr(blockhead)
432+
blockhead === :block || push!(cleanup_expr.args, terminatecondition(ls, us, n, inclmask, UF_cleanup))
433+
us_cleanup = nisunrolled ? UnrollSpecification(us, UF_cleanup, u₂) : UnrollSpecification(us, u₁, UF_cleanup)
434+
push!(cleanup_expr.args, lower_block(ls, us_cleanup, n, inclmask, UF_cleanup))
435+
push!(_q.args, cleanup_expr)
436+
end
437+
UFt > 0 && push!(_q.args, remblock)
438+
_q
409439
elseif remfirst
410440
numiters = length(loop) ÷ UF
411441
if numiters > 2
@@ -440,10 +470,14 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
440470
Expr( :block, q, remblock )
441471
end
442472
if !iszero(UFt)
473+
# if unroll_cleanup
443474
while true
444475
ust = nisunrolled ? UnrollSpecification(us, UFt, u₂) : UnrollSpecification(us, u₁, UFt)
445476
newblock = lower_block(ls, ust, n, remmask, UFt)
446477
if (UFt ≥ UF - 1 + nisvectorized) || UFt == Ureduct || loopisstatic
478+
if isone(num_loops(ls)) && isone(UFt) && isone(Ureduct)
479+
newblock = Expr(:block, definemask(loop), newblock)
480+
end
447481
push!(remblock.args, newblock)
448482
break
449483
end
@@ -459,6 +493,11 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
459493
end
460494
UFt += 1
461495
end
496+
# else
497+
# ust = nisunrolled ? UnrollSpecification(us, 1, u₂) : UnrollSpecification(us, u₁, 1)
498+
# # newblock = lower_block(ls, ust, n, remmask, 1)
499+
# push!(remblock.args, lower_no_unroll(ls, ust, n, inclmask, false, UF-1))
500+
# end
462501
end
463502
Expr(:block, Expr(:let, sl, q))
464503
end
@@ -529,26 +568,37 @@ end
529568
function initialize_outer_reductions!(ls::LoopSet, Umin::Int, Umax::Int, vectorized::Symbol, suffix::Union{Symbol,Nothing} = nothing)
530569
initialize_outer_reductions!(ls.preamble, ls, Umin, Umax, vectorized, suffix)
531570
end
532-
function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::Int, unrolledloop::Loop, vectorized::Symbol)
533-
ifq = Expr(:block)
534-
initialize_outer_reductions!(ifq, ls, Ulow, Uhigh, vectorized)
535-
push!(ifq.args, loopq)
536-
reduce_range!(ifq, ls, Ulow, Uhigh)
537-
loopbuffer = Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, Uhigh)
538-
comparison = if isstaticloop(unrolledloop)
539-
Expr(:call, lv(:scalar_less), length(unrolledloop), loopbuffer)
571+
function add_upper_comp_check(unrolledloop, loopbuffer)
572+
if isstaticloop(unrolledloop)
573+
Expr(:call, lv(:scalar_greaterequal), length(unrolledloop), loopbuffer)
540574
elseif unrolledloop.startexact
541575
if isone(unrolledloop.starthint)
542-
Expr(:call, lv(:scalar_less), unrolledloop.stopsym, loopbuffer)
576+
Expr(:call, lv(:scalar_greaterequal), unrolledloop.stopsym, loopbuffer)
543577
else
544-
Expr(:call, lv(:scalar_less), Expr(:call, lv(:vsub), unrolledloop.stopsym, unrolledloop.starthint-1), loopbuffer)
578+
Expr(:call, lv(:scalar_greaterequal), Expr(:call, lv(:vsub), unrolledloop.stopsym, unrolledloop.starthint-1), loopbuffer)
545579
end
546580
elseif unrolledloop.stopexact
547-
Expr(:call, lv(:scalar_less), Expr(:call, lv(:vsub), unrolledloop.stophint+1, unrolledloop.startsym), loopbuffer)
581+
Expr(:call, lv(:scalar_greaterequal), Expr(:call, lv(:vsub), unrolledloop.stophint+1, unrolledloop.startsym), loopbuffer)
548582
else# both are given by symbols
549-
Expr(:call, lv(:scalar_less), Expr(:call, lv(:vsub), unrolledloop.stopsym, Expr(:call,lv(:vsub),unrolledloop.startsym, staticexpr(1))), loopbuffer)
583+
Expr(:call, lv(:scalar_greaterequal), Expr(:call, lv(:vsub), unrolledloop.stopsym, Expr(:call,lv(:vsub),unrolledloop.startsym, staticexpr(1))), loopbuffer)
584+
end
585+
end
586+
function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::Int, unrolledloop::Loop, vectorized::Symbol, reductisvectorized::Bool)
587+
ifq = Expr(:block)
588+
initialize_outer_reductions!(ifq, ls, Ulow, Uhigh, vectorized)
589+
push!(ifq.args, loopq)
590+
_Ulow = Uhigh >>> 1; _Uhigh = Uhigh
591+
while _Ulow > Ulow
592+
reduce_range!(ifq, ls, _Ulow, _Uhigh)
593+
_Uhigh = _Ulow
594+
_Ulow >>>= 1
595+
end
596+
reduce_range!(ifq, ls, Ulow, _Uhigh)
597+
ncomparison = if reductisvectorized
598+
add_upper_comp_check(unrolledloop, Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, Uhigh))
599+
else
600+
add_upper_comp_check(unrolledloop, Uhigh)
550601
end
551-
ncomparison = Expr(:call, :!, comparison)
552602
Expr(:if, ncomparison, ifq)
553603
end
554604
function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
@@ -760,7 +810,9 @@ function calc_Ureduct(ls::LoopSet, us::UnrollSpecification)
760810
elseif u₂ == -1
761811
loopisstatic = isstaticloop(getloop(ls, names(ls)[u₁loopnum]))
762812
loopisstatic &= ((vectorizedloopnum != u₁loopnum) | (!iszero(ls.vector_width[])))
763-
loopisstatic ? u₁ : min(u₁, 4)
813+
# loopisstatic ? u₁ : min(u₁, 4) # much worse than the other two options, don't use this one
814+
loopisstatic ? u₁ : (u₁ ≥ 4 ? 2 : 1)
815+
# loopisstatic ? u₁ : 1
764816
else
765817
8#u₂#u₁
766818
# elseif num_loops(ls) == u₁loopnum

0 commit comments

Comments
 (0)