Skip to content

Commit 8cb248c

Browse files
committed
Improvements to lowering and not creating excessive numbers of accumulation vectors when twice-unrolling.
1 parent a0df499 commit 8cb248c

16 files changed

+520
-299
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ export LowDimArray, stridedpointer, vectorizable,
2626
vmap, vmap!, vmapnt, vmapnt!, vmapntt, vmapntt!,
2727
vfilter, vfilter!
2828

29+
const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloopeltype##")
30+
2931

3032
include("vectorizationbase_extensions.jl")
3133
include("predicates.jl")

src/add_compute.jl

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,48 @@ end
100100
# end
101101
# false
102102
# end
103+
function add_reduced_deps!(op::Operation, reduceddeps::Vector{Symbol})
104+
# op.dependencies = copy(loopdependencies(op))
105+
# mergesetv!(loopdependencies(op), reduceddeps)
106+
reduceddepsop = reduceddependencies(op)
107+
if reduceddepsop === NODEPENDENCY
108+
op.reduced_deps = copy(reduceddeps)
109+
else
110+
mergesetv!(reduceddepsop, reduceddeps)
111+
end
112+
# reduceddepsop = reducedchildren(op)
113+
# if reduceddepsop === NODEPENDENCY
114+
# op.reduced_children = copy(reduceddeps)
115+
# else
116+
# mergesetv!(reduceddepsop, reduceddeps)
117+
# end
118+
nothing
119+
end
103120

104-
# function substitute_op_in_parents!(vparents::Vector{Operation}, replacer::Operation, replacee::Operation)
105-
# for i ∈ eachindex(vparents)
106-
# opp = vparents[i]
107-
# if opp === replacee
108-
# vparents[i] = replacer
109-
# else
110-
# substitute_op_in_parents!(parents(opp), replacer, replacee)
111-
# end
112-
# end
121+
# function substitute_op_in_parents!(
122+
# vparents::Vector{Operation}, replacer::Operation, replacee::Operation, reduceddeps::Vector{Symbol}
123+
# )
124+
# @show replacer replacee
125+
# #
126+
# substitute_op_in_parents_recurse!(vparents, replacer, replacee)
113127
# end
128+
function substitute_op_in_parents!(
129+
vparents::Vector{Operation}, replacer::Operation, replacee::Operation, reduceddeps::Vector{Symbol}
130+
)
131+
found = false
132+
for i eachindex(vparents)
133+
opp = vparents[i]
134+
if opp === replacee
135+
vparents[i] = replacer
136+
found = true
137+
else
138+
fopp = substitute_op_in_parents!(parents(opp), replacer, replacee, reduceddeps)
139+
fopp && add_reduced_deps!(opp, reduceddeps)
140+
found |= fopp
141+
end
142+
end
143+
found
144+
end
114145

115146

116147
function add_reduction_update_parent!(
@@ -157,8 +188,8 @@ function add_reduction_update_parent!(
157188
if instr.instr (:-, :vsub!, :vsub, :/, :vfdiv!, :vfidiv!)
158189
update_deps!(deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
159190
end
160-
# elseif !isouterreduction
161-
# substitute_op_in_parents!(vparents, reductinit, parent)
191+
elseif !isouterreduction
192+
substitute_op_in_parents!(vparents, reductinit, parent, reduceddeps)
162193
end
163194
update_reduction_status!(vparents, reduceddeps, name(reductinit))
164195
# this is the op added by add_compute

src/determinestrategy.jl

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ function add_constant_offset_load_elmination_cost!(
618618
end
619619
end
620620

621+
621622
# Just tile outer two loops?
622623
# But optimal order within tile must still be determined
623624
# as well as size of the tiles.
@@ -789,35 +790,93 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
789790
new_order, state = iter
790791
end
791792
end
793+
794+
795+
"""
796+
This function searches for unrolling combinations that will cause LoopVectorization to generate invalid code.
797+
798+
Currently, it is only searching for one scenario, based on how `isunrolled_sym` and lowering currently work.
799+
`isunrolledsym` tries to avoid the creation of excessive numbers of accumulation vectors in the case of reductions.
800+
If an unrolled loop isn't reduced, it will need separate vectors.
801+
But separate vectors for a reduced loop are not needed. Separate vectors will help to break up dependency chains,
802+
so you want to unroll at least one of the loops. However, reductions demand combining all the separate vectors,
803+
and each vector also eats a valuable register, so it's best to avoid excessive numbers these accumulation vectors.
804+
805+
806+
If a reduced op depends on both unrolled loops (u1 and u2), it will check over which of these it is reduced. If...
807+
neither: cannot avoid unrolling it along both
808+
one of them: don't unroll the reduced loop
809+
both of them: don't unroll along u2 (unroll along u1)
810+
811+
Now, a look at lowering:
812+
It interleaves u1-unrolled operations in an effort to improve superscalar parallelism,
813+
while u2-unrolled operations are lowered by block. E.g., op_u2id_u1id (as they're printed):
814+
815+
u2 = 0
816+
opa_0_0 = fa(...)
817+
opa_0_1 = fa(...)
818+
opa_0_2 = fa(...)
819+
opb_0_0 = fb(...)
820+
opb_0_1 = fb(...)
821+
opb_0_2 = fb(...)
822+
u2 += 1
823+
opa_1_0 = fa(...)
824+
opa_1_1 = fa(...)
825+
opa_1_2 = fa(...)
826+
opb_1_0 = fb(...)
827+
opb_1_1 = fb(...)
828+
opb_1_2 = fb(...)
829+
830+
what if `opa` vectors were not replicated across u1?
831+
opa_0_ = fa(...)
832+
opa_0_ = fa(...)
833+
opa_0_ = fa(...)
834+
835+
Then unless `fa` was taking the previous `opa_0_`s as an argument and updating them, this would be wrong, because it'd be overwriting the previous `opa_0_` values.
836+
"""
837+
function reject_candidate(op::Operation, u₁loopsym::Symbol, u₂loopsym::Symbol)
838+
if iscompute(op) && u₁loopsym reduceddependencies(op) && u₁loopsym loopdependencies(op)
839+
if u₂loopsym reduceddependencies(op) && !any(opp -> name(opp) === name(op), parents(op))
840+
return true
841+
end
842+
end
843+
false
844+
end
845+
846+
function reject_candidate(ls::LoopSet, u₁loopsym::Symbol, u₂loopsym::Symbol)
847+
for op operations(ls)
848+
reject_candidate(op, u₁loopsym, u₂loopsym) && return true
849+
end
850+
false
851+
end
852+
792853
function choose_tile(ls::LoopSet)
793854
lo = LoopOrders(ls)
794855
best_order = copyto!(ls.loop_order.bestorder, lo.syms)
795856
bestu₁ = bestu₂ = best_vec = first(best_order) # filler
796-
new_order, state = iterate(lo) # right now, new_order === best_order
797857
u₁, u₂, lowest_cost = 0, 0, Inf
798-
nloops = length(new_order)
799-
while true
800-
for new_vec new_order # view to skip first
801-
for nt 1:nloops-1
802-
newu₂ = new_order[nt]
803-
for newu₁ @view(new_order[nt+1:end])
804-
u₁temp, u₂temp, cost_temp = evaluate_cost_tile(ls, new_order, newu₁, newu₂, new_vec)
805-
if cost_temp < lowest_cost
806-
lowest_cost = cost_temp
807-
u₁, u₂ = u₁temp, u₂temp
808-
best_vec = new_vec
809-
bestu₂ = newu₂
810-
bestu₁ = newu₁
811-
copyto!(best_order, new_order)
812-
save_tilecost!(ls)
813-
end
858+
for newu₂ lo.syms, newu₁ lo.syms#@view(new_order[nt+1:end])
859+
((newu₁ == newu₂) || reject_candidate(ls, newu₁, newu₂)) && continue
860+
new_order, state = iterate(lo) # right now, new_order === best_order
861+
while true
862+
for new_vec new_order # view to skip first
863+
u₁temp, u₂temp, cost_temp = evaluate_cost_tile(ls, new_order, newu₁, newu₂, new_vec)
864+
if cost_temp < lowest_cost
865+
lowest_cost = cost_temp
866+
u₁, u₂ = u₁temp, u₂temp
867+
best_vec = new_vec
868+
bestu₂ = newu₂
869+
bestu₁ = newu₁
870+
copyto!(best_order, new_order)
871+
save_tilecost!(ls)
814872
end
815873
end
874+
iter = iterate(lo, state)
875+
iter === nothing && break
876+
new_order, state = iter
816877
end
817-
iter = iterate(lo, state)
818-
iter === nothing && return best_order, bestu₁, bestu₂, best_vec, u₁, u₂, lowest_cost
819-
new_order, state = iter
820878
end
879+
best_order, bestu₁, bestu₂, best_vec, u₁, u₂, lowest_cost
821880
end
822881
# Last in order is the inner most loop
823882
function choose_order_cost(ls::LoopSet)

src/graphs.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,26 @@ function Loop(itersymbol::Symbol, start::Union{Int,Symbol}, stop::Union{Int,Symb
4444
end
4545
Base.length(loop::Loop) = 1 + loop.stophint - loop.starthint
4646
isstaticloop(loop::Loop) = loop.startexact & loop.stopexact
47-
function startloop(loop::Loop, isvectorized, W, itersymbol)
47+
function startloop(loop::Loop, isvectorized, itersymbol)
4848
startexact = loop.startexact
4949
if isvectorized
5050
if startexact
51-
Expr(:(=), itersymbol, Expr(:call, lv(:_MM), W, loop.starthint))
51+
Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.starthint))
5252
else
53-
Expr(:(=), itersymbol, Expr(:call, lv(:_MM), W, loop.startsym))
53+
Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.startsym))
5454
end
5555
elseif startexact
5656
Expr(:(=), itersymbol, loop.starthint)
5757
else
5858
Expr(:(=), itersymbol, Expr(:call, lv(:unwrap), loop.startsym))
5959
end
6060
end
61-
function vec_looprange(loop::Loop, W::Symbol, UF::Int, mangledname::Symbol)
61+
function vec_looprange(loop::Loop, UF::Int, mangledname::Symbol)
6262
isunrolled = UF > 1
6363
incr = if isunrolled
64-
Expr(:call, lv(:valmuladd), W, UF, -2)
64+
Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, UF, -2)
6565
else
66-
Expr(:call, lv(:valsub), W, 2)
66+
Expr(:call, lv(:valsub), VECTORWIDTHSYMBOL, 2)
6767
end
6868
if loop.stopexact # split for type stability
6969
Expr(:call, lv(:scalar_less), mangledname, Expr(:call, :-, loop.stophint, incr))
@@ -80,22 +80,22 @@ function looprange(loop::Loop, incr::Int, mangledname::Symbol)
8080
end
8181
end
8282
function terminatecondition(
83-
loop::Loop, us::UnrollSpecification, n::Int, W::Symbol, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
83+
loop::Loop, us::UnrollSpecification, n::Int, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
8484
)
8585
if !isvectorized(us, n)
8686
looprange(loop, UF, mangledname)
8787
elseif inclmask
8888
looprange(loop, 1, mangledname)
8989
else
90-
vec_looprange(loop, W, UF, mangledname) # may not be u₂loop
90+
vec_looprange(loop, UF, mangledname) # may not be u₂loop
9191
end
9292
end
93-
function incrementloopcounter(us::UnrollSpecification, n::Int, W::Symbol, mangledname::Symbol, UF::Int = unrollfactor(us, n))
93+
function incrementloopcounter(us::UnrollSpecification, n::Int, mangledname::Symbol, UF::Int = unrollfactor(us, n))
9494
if isvectorized(us, n)
9595
if UF == 1
96-
Expr(:(=), mangledname, Expr(:call, lv(:valadd), W, mangledname))
96+
Expr(:(=), mangledname, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, mangledname))
9797
else
98-
Expr(:+=, mangledname, Expr(:call, lv(:valmul), W, UF))
98+
Expr(:+=, mangledname, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UF))
9999
end
100100
else
101101
Expr(:+=, mangledname, UF)
@@ -158,8 +158,6 @@ struct LoopSet
158158
reg_pres::Matrix{Float64}
159159
included_vars::Vector{Bool}
160160
place_after_loop::Vector{Bool}
161-
W::Symbol
162-
T::Symbol
163161
mod::Symbol
164162
end
165163

@@ -240,10 +238,9 @@ end
240238
# false
241239
# end
242240

243-
244241
includesarray(ls::LoopSet, array::Symbol) = array ls.includedarrays
245242

246-
function LoopSet(mod::Symbol, W = Symbol("##Wvecwidth##"), T = Symbol("##Tloopeltype##"))# = :LoopVectorization)
243+
function LoopSet(mod::Symbol)
247244
LoopSet(
248245
Symbol[], [0], Loop[],
249246
Dict{Symbol,Operation}(),
@@ -259,8 +256,7 @@ function LoopSet(mod::Symbol, W = Symbol("##Wvecwidth##"), T = Symbol("##Tloopel
259256
ArrayReferenceMeta[],
260257
Matrix{Float64}(undef, 4, 2),
261258
Matrix{Float64}(undef, 4, 2),
262-
Bool[], Bool[],
263-
W, T, mod
259+
Bool[], Bool[], mod
264260
)
265261
end
266262

0 commit comments

Comments
 (0)