Skip to content

Commit e0f5dcf

Browse files
committed
Prevent certain memory/store-dominated tasks from over-unrolling.
1 parent 696a0d7 commit e0f5dcf

File tree

8 files changed

+177
-75
lines changed

8 files changed

+177
-75
lines changed

benchmark/benchmarkflops.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,9 @@ sse_totwotuple(s::Integer) = ((3s) >> 1, s >> 1)
228228

229229
function sse_bench!(br, s, i)
230230
N, P = sse_totwotuple(s)
231-
y = rand(N); β = rand(P)
232-
X = randn(N, P)
233-
Xβ = similar(y)
231+
y = rand(N); β = rand(P);
232+
X = randn(N, P);
233+
Xβ = similar(y);
234234
lpblas = sse!(Xβ, y, X, β)
235235
n_gflop = 2e-9*(P*N + 2N)
236236
br[1,i] = n_gflop / @belapsed jOLSlp_avx($y, $X, $β)

src/determinestrategy.jl

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ function evaluate_cost_unroll(
147147
# hasintersection(reduceddependencies(op), nested_loop_syms) && return Inf
148148
rd = reduceddependencies(op)
149149
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return Inf
150+
if isstore(op) #TODO: DRY (this is repeated in evaluate_cost_tile)
151+
loadstoredeps = store_load_deps(op)
152+
if !isnothing(loadstoredeps)
153+
any(s -> (s ∉ loadstoredeps), nested_loop_syms) && return Inf
154+
end
155+
end
150156
included_vars[id] = true
151157
total_cost += iter * first(cost(ls, op, vectorized, Wshift, size_T))
152158
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
@@ -196,6 +202,7 @@ function unroll_no_reductions(ls, order, vectorized)
196202

197203
compute_rt = 0.0
198204
load_rt = 0.0
205+
store_rt = 0.0
199206
unrolled = last(order)
200207
if unrolled === vectorized && length(order) > 1
201208
unrolled = order[end-1]
@@ -207,13 +214,31 @@ function unroll_no_reductions(ls, order, vectorized)
207214
compute_rt += first(cost(ls, op, vectorized, Wshift, size_T))
208215
elseif isload(op)
209216
load_rt += first(cost(ls, op, vectorized, Wshift, size_T))
217+
elseif isstore(op)
218+
store_rt += first(cost(ls, op, vectorized, Wshift, size_T))
210219
end
211220
end
221+
# @show compute_rt, load_rt, store_rt
212222
# heuristic guess
213223
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
214-
rt = max(compute_rt, load_rt)
215-
# (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
216-
(iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
224+
memory_rt = load_rt + store_rt
225+
u = if compute_rt > memory_rt
226+
max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / compute_rt) ) ))
227+
elseif iszero(compute_rt)
228+
4
229+
else
230+
max(1, min(4, round(Int, 2compute_rt / memory_rt)))
231+
end
232+
# commented out here is to decide to align loops
233+
# if memory_rt > compute_rt && isone(u) && (length(order) > 1) && (last(order) === vectorized) && length(getloop(ls, last(order))) > 8W
234+
# ls.align_loops[] = findfirst(operations(ls)) do op
235+
# isstore(op) && dependson(op, unrolled)
236+
# end
237+
# end
238+
u, unrolled
239+
# rt = max(compute_rt, load_rt + store_rt)
240+
# # (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
241+
# (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
217242
end
218243
function determine_unroll_factor(
219244
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vectorized::Symbol
@@ -341,7 +366,8 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
341366
u₂float = (RR - u₁float*R₂)/(u₁float*R₁)
342367
if !(isfinite(u₂float) & isfinite(u₁float)) # brute force
343368
u₁low = u₂low = 1
344-
u₁high = u₂high = REGISTER_COUNT == 32 ? 10 : 6#8
369+
u₁high = iszero(X₂) ? 2 : (REGISTER_COUNT == 32 ? 8 : 6)
370+
u₂high = iszero(X₃) ? 2 : (REGISTER_COUNT == 32 ? 8 : 6)
345371
return solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
346372
end
347373
u₁low = floor(Int, u₁float)
@@ -632,13 +658,30 @@ function load_elimination_cost_factor!(
632658
false
633659
end
634660
end
635-
function loadintostore(ls::LoopSet, op::Operation)
636-
# isload(op) || return false # leads to bad behavior more than it helps
637-
# for opp ∈ operations(ls)
638-
# isstore(opp) && opp.ref == op.ref && return true
639-
# end
661+
# function loadintostore(ls::LoopSet, op::Operation)
662+
# isload(op) || return false # leads to bad behavior more than it helps
663+
# for opp ∈ operations(ls)
664+
# isstore(opp) && opp.ref == op.ref && return true
665+
# end
666+
# false
667+
# end
668+
"""
    store_load_deps!(deps::Vector{Symbol}, op::Operation, compref = op.ref)

Walk the parents of `op` looking for a load whose `ref` equals `compref`.

For every parent visited, its loop dependencies and reduced dependencies are
appended to `deps` (skipping symbols already present), so `deps` is mutated even
when no matching load is found.

Returns `true` as soon as a matching load is reached, `false` otherwise. A load
whose `ref` does not match is not searched any further; only non-load parents
are recursed into.
"""
function store_load_deps!(deps::Vector{Symbol}, op::Operation, compref = op.ref)
    for opp ∈ parents(op)
        # Record every loop symbol this ancestor touches, without duplicates.
        for ld ∈ loopdependencies(opp)
            (ld ∈ deps) || push!(deps, ld)
        end
        for ld ∈ reduceddependencies(opp)
            (ld ∈ deps) || push!(deps, ld)
        end
        if isload(opp)
            (opp.ref == compref) && return true
        else
            store_load_deps!(deps, opp, compref) && return true
        end
    end
    false
end
680+
"""
    store_load_deps(op::Operation)

Return the loop symbols gathered by `store_load_deps!` for a store `op`, starting
from a copy of the store's own loop dependencies. Returns `nothing` when `op` is
not a store, or when `store_load_deps!` reports no matching load.
"""
function store_load_deps(op::Operation)
    if !isstore(op)
        return nothing
    end
    # Copy so that the operation's own dependency vector is never mutated.
    collected = copy(loopdependencies(op))
    found = store_load_deps!(collected, op)
    return found ? collected : nothing
end
642685
function add_constant_offset_load_elmination_cost!(
643686
X, R, choose_to_inline, ls::LoopSet, op::Operation, iters, unrollsyms::UnrollSymbols, u₁reduces::Bool, u₂reduces::Bool, Wshift::Int, size_T::Int, opisininnerloop::Bool
644687
)
@@ -755,7 +798,16 @@ function evaluate_cost_tile(
755798
# it must also be a subset of defined symbols
756799
all(ld -> ld ∈ nested_loop_syms, loopdependencies(op)) || continue
757800
rd = reduceddependencies(op)
758-
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return 0,0,Inf,false
801+
if hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)]))
802+
# @show rd, op itersym, nested_loop_syms @view(nested_loop_syms[1:end-length(rd)])
803+
return 0,0,Inf,false
804+
end
805+
if isstore(op)
806+
loadstoredeps = store_load_deps(op)
807+
if !isnothing(loadstoredeps)
808+
any(s -> (s ∉ loadstoredeps), nested_loop_syms) && return 0,0,Inf,false
809+
end
810+
end
759811
included_vars[id] = true
760812
if isconstant(op)
761813
depends_on_u₁, depends_on_u₂ = isunrolled_sym(op, u₁loopsym, u₂loopsym)
@@ -774,6 +826,7 @@ function evaluate_cost_tile(
774826
innerloop ∈ loopdependencies(op) && set_upstream_family!(descendentsininnerloop, op, true)
775827
end
776828
end
829+
irreducible_storecosts = 0.0
777830
for (id, op) ∈ enumerate(ops)
778831
iters[id] == -99.9 && continue
779832
opisininnerloop = descendentsininnerloop[id]
@@ -793,18 +846,26 @@ function evaluate_cost_tile(
793846
rt += 0.5VectorizationBase.REGISTER_SIZE / VectorizationBase.CACHELINE_SIZE
794847
prefetch_good_idea = true
795848
end
796-
rp = (opisininnerloop && !(loadintostore(ls, op))) ? rp : zero(rp) # we only care about register pressure within the inner most loop
797-
# rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
849+
# rp = (opisininnerloop && !(loadintostore(ls, op))) ? rp : zero(rp) # we only care about register pressure within the inner most loop
850+
rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
798851
rto = rt
799852
rt *= iters[id]
853+
if isstore(op) & (!u₁reducesrt) & (!u₂reducesrt)
854+
irreducible_storecosts += rt
855+
end
800856
update_costs!(cost_vec, rt, u₁reducesrt, u₂reducesrt)
801857
update_costs!(reg_pressure, rp, u₁reducesrp, u₂reducesrp)
802858
end
803859
# @inbounds ((cost_vec[4] > 0) || ((cost_vec[2] > 0) & (cost_vec[3] > 0))) || return 0,0,Inf,false
804860
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1
805861
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
806862
round_uᵢ = prefetch_good_idea ? (u₁v ? 1 : (u₂v ? 2 : 0)) : 0
807-
u₁, u₂, ucost = solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, round_uᵢ)
863+
if irreducible_storecosts / sum(cost_vec) ≥ 0.25
864+
u₁, u₂ = (1, 1)
865+
ucost = unroll_cost(cost_vec, 1, 1, length(getloop(ls, u₁loopsym)), length(getloop(ls, u₂loopsym)))
866+
else
867+
u₁, u₂, ucost = solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, round_uᵢ)
868+
end
808869
outer_reduct_penalty = length(ls.outer_reductions) * (u₁ + isodd(u₁))
809870
favor_bigger_u₂ = u₁ - u₂
810871
favor_smaller_vectorized = u₁v ? ( u₁ - u₂ ) : (u₂v ? ( u₂ - u₁ ) : 0 )

src/graphs.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,9 @@ struct LoopSet
250250
unrollspecification::Base.RefValue{UnrollSpecification}
251251
loadelimination::Base.RefValue{Bool}
252252
lssm::Base.RefValue{LoopStartStopManager}
253-
isbroadcast::Base.RefValue{Bool}
254253
vector_width::Base.RefValue{Int}
254+
# align_loops::Base.RefValue{Int}
255+
isbroadcast::Base.RefValue{Bool}
255256
mod::Symbol
256257
end
257258

@@ -340,7 +341,7 @@ function LoopSet(mod::Symbol)
340341
Matrix{Float64}(undef, 5, 2),
341342
Bool[], Bool[], Ref{UnrollSpecification}(),
342343
Ref(false), Ref{LoopStartStopManager}(),
343-
Ref(false), Ref(0), mod
344+
Ref(0), #=Ref(0),=# Ref(false), mod
344345
)
345346
end
346347

src/loopstartstopmanager.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ function indices_calculated_by_pointer_offsets(ls::LoopSet, ar::ArrayReferenceMe
3838
# out[i] = out[j - offset]
3939
# continue
4040
# end
41-
if (!li[i]) || multiple_with_name(vptr(ar), ls.lssm[].uniquearrayrefs) || (iszero(ls.vector_width[]) && isstaticloop(getloop(ls, ind)))
41+
if (!li[i]) || multiple_with_name(vptr(ar), ls.lssm[].uniquearrayrefs) ||
42+
(iszero(ls.vector_width[]) && isstaticloop(getloop(ls, ind)))# ||
43+
# ((ls.align_loops[] > 0) && (first(names(ls)) == ind))
4244
out[i] = false
4345
elseif (isone(ii) && (first(looporder) === ind))
4446
out[i] = otherindexunrolled(ls, ind, ar)
@@ -82,7 +84,12 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
8284
uliv[i] = 0
8385
push!(gespinds.args, Expr(:call, lv(:Zero)))
8486
push!(offsetprecalc_descript.args, 0)
85-
elseif isbroadcast || ((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) || multiple_with_name(vptr(ar), allarrayrefs)) || (iszero(ls.vector_width[]) && isstaticloop(getloop(ls, ind)))
87+
elseif isbroadcast ||
88+
((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) ||
89+
multiple_with_name(vptr(ar), allarrayrefs)) ||
90+
(iszero(ls.vector_width[]) && isstaticloop(getloop(ls, ind)))# ||
91+
# ((ls.align_loops[] > 0) && (first(names(ls)) == ind))
92+
8693
# Not doing normal offset indexing
8794
uliv[i] = -findfirst(isequal(ind), looporder)::Int
8895
push!(gespinds.args, Expr(:call, lv(:Zero)))

src/lowering.jl

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,12 @@ function lower_block(
144144
# loopsym = mangletiledsym(order[n], us, n)
145145
loopsym = order[n]
146146
# push!(blockq.args, incrementloopcounter(us, n, loopsym, UF))
147+
# if n > 1 || iszero(ls.align_loops[])
147148
push!(blockq.args, incrementloopcounter(ls, us, n, UF))
149+
# else
150+
# loopsym = names(ls)[n]
151+
# push!(blockq.args, Expr(:(=), loopsym, Expr(:call, lv(:vadd), loopsym, Symbol("##ALIGNMENT#STEP##"))))
152+
# end
148153
blockq
149154
end
150155

@@ -185,48 +190,43 @@ end
185190
# Expr(:block, loopiteratesatleastonce(loop), q)
186191
# end
187192
# end
188-
function lower_unroll_for_throughput(ls::LoopSet, us::UnrollSpecification, loop::Loop, loopsym::Symbol)
189-
UF = 4
190-
sl = startloop(ls, us, 1, UF)
191-
tcc = terminatecondition(ls, us, 1, false, 1)
192-
tcu = terminatecondition(ls, us, 1, false, UF)
193-
body = lower_block(ls, us, 1, false, 1)
194-
loopisstatic = isstaticloop(loop)
195-
tcu = loopisstatic ? tcu : expect(tcu)
196-
termcondu = gensym(:maybetermu)
197-
unrolledbody = Expr(:block)
198-
foreach(_ -> push!(unrolledbody.args, body), 1:UF)
199-
200-
# q = Expr(
201-
# :block,
202-
# Expr(:while, tcu, unrolledbody),
203-
# Expr(:while, tcc, body)
204-
# )
205-
# return Expr(:let, sl, q)
206-
207-
push!(unrolledbody.args, Expr(:(=), termcondu, tcu))
208-
209-
unrolledloop = Expr(
210-
:block,
211-
Expr(:while, termcondu, unrolledbody),
212-
Expr(:while, tcc, body)
213-
)
214-
215-
termcond = gensym(:maybeterm)
216-
singleloop = Expr(
217-
:block,
218-
Expr(:(=), termcond, true),
219-
Expr(:while, termcond, Expr(:block, body, Expr(:(=), termcond, tcc)))
220-
)
221-
222-
q = Expr(
223-
:block,
224-
assume(tcc),
225-
Expr(:(=), termcondu, tcu),
226-
Expr(:if, termcondu, unrolledloop, singleloop)
227-
)
228-
Expr(:let, sl, q)
229-
end
193+
# function lower_unroll_for_throughput(ls::LoopSet, us::UnrollSpecification, loop::Loop, loopsym::Symbol)
194+
# UF = 4
195+
# sl = startloop(ls, us, 1, UF)
196+
# tcc = terminatecondition(ls, us, 1, false, 1)
197+
# tcu = terminatecondition(ls, us, 1, false, UF)
198+
# body = lower_block(ls, us, 1, false, 1)
199+
# loopisstatic = isstaticloop(loop)
200+
# tcu = loopisstatic ? tcu : expect(tcu)
201+
# termcondu = gensym(:maybetermu)
202+
# unrolledbody = Expr(:block)
203+
# foreach(_ -> push!(unrolledbody.args, body), 1:UF)
204+
# # q = Expr(
205+
# # :block,
206+
# # Expr(:while, tcu, unrolledbody),
207+
# # Expr(:while, tcc, body)
208+
# # )
209+
# # return Expr(:let, sl, q)
210+
# push!(unrolledbody.args, Expr(:(=), termcondu, tcu))
211+
# unrolledloop = Expr(
212+
# :block,
213+
# Expr(:while, termcondu, unrolledbody),
214+
# Expr(:while, tcc, body)
215+
# )
216+
# termcond = gensym(:maybeterm)
217+
# singleloop = Expr(
218+
# :block,
219+
# Expr(:(=), termcond, true),
220+
# Expr(:while, termcond, Expr(:block, body, Expr(:(=), termcond, tcc)))
221+
# )
222+
# q = Expr(
223+
# :block,
224+
# assume(tcc),
225+
# Expr(:(=), termcondu, tcu),
226+
# Expr(:if, termcondu, unrolledloop, singleloop)
227+
# )
228+
# Expr(:let, sl, q)
229+
# end
230230

231231
function assume(ex)
232232
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:assume)), ex)
@@ -247,6 +247,22 @@ function loopiteratesatleastonce(loop::Loop, as::Bool = true)
247247
# as ? assume(comp) : expect(comp)
248248
assume(comp)
249249
end
250+
# @inline step_to_align(x, ::Val{W}) where {W} = step_to_align(pointer(x), Val{W}())
251+
# @inline step_to_align(x::Ptr{T}, ::Val{W}) where {W,T} = vsub(W, reinterpret(Int, x) & (W - 1))
252+
# function align_inner_loop_expr(ls::LoopSet, us::UnrollSpecification, loop::Loop)
253+
# alignincr = Symbol("##ALIGNMENT#STEP##")
254+
# looplength = gensym(:inner_loop_length)
255+
# pushpreamble!(ls, Expr(:(=), looplength, looplengthexpr(loop)))
256+
# vp = vptr(operations(ls)[ls.align_loops[]])
257+
# align_step = Expr(:call, :min, Expr(:call, lv(:step_to_align), vp, VECTORWIDTHSYMBOL), looplength)
258+
# Expr(
259+
# :block,
260+
# Expr(:(=), alignincr, align_step),
261+
# maskexpr(alignincr),
262+
# lower_block(ls, us, 1, true, 1)
263+
# )
264+
# end
265+
250266
function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool)
251267
usorig = ls.unrollspecification[]
252268
nisvectorized = isvectorized(us, n)
@@ -260,11 +276,14 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
260276
sl = startloop(ls, us, n)
261277
tc = terminatecondition(ls, us, n, inclmask, 1)
262278
body = lower_block(ls, us, n, inclmask, 1)
263-
isstatic = isstaticloop(loop)
264-
279+
# align_loop = isone(n) & (ls.align_loops[] > 0)
280+
isstatic = isstaticloop(loop)# & (!align_loop)
265281
if !isstatic && (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !inclmask
266282
tc = expect(tc)
267283
end
284+
# q = if align_loop
285+
# Expr(:block, align_inner_loop_expr(ls, us, loop), Expr(:while, tc, body))
286+
# elseif nisvectorized
268287
q = if nisvectorized
269288
# Expr(:block, loopiteratesatleastonce(loop, true), Expr(:while, expect(tc), body))
270289
Expr(:block, Expr(:while, tc, body))
@@ -283,12 +302,15 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
283302
# push!(body.args, Expr(:||, expect(tc), Expr(:break)))
284303
# Expr(:block, Expr(:while, true, body))
285304
end
286-
287305
if nisvectorized
288306
# tc = terminatecondition(loop, us, n, loopsym, true, 1)
289307
tc = terminatecondition(ls, us, n, true, 1)
290308
body = lower_block(ls, us, n, true, 1)
291-
isone(num_loops(ls)) && pushfirst!(body.args, definemask(loop))
309+
if isone(num_loops(ls))
310+
pushfirst!(body.args, definemask(loop))
311+
# elseif align_loop
312+
# pushfirst!(body.args, definemask_for_alignment_cleanup(loop))
313+
end
292314
push!(q.args, Expr(:if, tc, body))
293315
end
294316
Expr(:block, Expr(:let, sl, q))
@@ -571,7 +593,7 @@ end
571593
function definemask(loop::Loop)
572594
if isstaticloop(loop)
573595
maskexpr(length(loop))
574-
elseif loop.startexact && loop.starthint == 1
596+
elseif loop.startexact && isone(loop.starthint)
575597
maskexpr(loop.stopsym)
576598
else
577599
lexpr = if loop.startexact
@@ -584,6 +606,14 @@ function definemask(loop::Loop)
584606
maskexpr(lexpr)
585607
end
586608
end
609+
"""
    definemask_for_alignment_cleanup(loop::Loop)

Build the mask expression for the iterations left between the loop's current
iteration symbol and its stop: `maskexpr(stop + 1 - itersymbol)`.

When `loop.stopexact` is set, the stop is the literal `loop.stophint`; otherwise
it is the runtime symbol `loop.stopsym`.
"""
function definemask_for_alignment_cleanup(loop::Loop)
    # Both branches must read the same iteration-symbol field. The original
    # mixed `loop.itersym` and `loop.itersymbol`; unified on `itersymbol`.
    # NOTE(review): confirm `itersymbol` is the field declared on `Loop`.
    lexpr = if loop.stopexact
        Expr(:call, lv(:vsub), loop.stophint + 1, loop.itersymbol)
    else
        Expr(:call, lv(:vsub), Expr(:call, lv(:vadd), loop.stopsym, 1), loop.itersymbol)
    end
    maskexpr(lexpr)
end
587617
function define_eltype_vec_width!(q::Expr, ls::LoopSet, vectorized)
588618
push!(q.args, Expr(:(=), ELTYPESYMBOL, determine_eltype(ls)))
589619
push!(q.args, Expr(:(=), VECTORWIDTHSYMBOL, determine_width(ls, vectorized)))

0 commit comments

Comments
 (0)