Skip to content

Commit e89f09e

Browse files
committed
More transition to 2-based tabs, and add cost check before deciding whether to simd a load
1 parent e366339 commit e89f09e

File tree

2 files changed

+92
-73
lines changed

2 files changed

+92
-73
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.69"
4+
version = "0.12.70"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/codegen/lower_load.jl

Lines changed: 91 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -119,55 +119,73 @@ function pushbroadcast!(q::Expr, mvar::Symbol)
119119
push!(q.args, Expr(:(=), broadcastedname(mvar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, mvar)))
120120
end
121121

122+
function child_cost_untill_vectorized(op::Operation)
123+
isvectorized(op) && return 0.0
124+
c = 0.0
125+
for child ∈ children(op)
126+
if (!isvectorized(child) & iscompute(child))
127+
# FIXME: can double count
128+
c += COST[instruction(child).instr].scalar_reciprocal_throughput + child_cost_untill_vectorized(child)
129+
end
130+
end
131+
c
132+
end
133+
function vectorization_profitable(op::Operation)
134+
# if op is vectorized itself, return true
135+
isvectorized(op) && return true
136+
# otherwise, check if descendents until hitting a vectorized portion are expensive enough
137+
child_cost_untill_vectorized(op) ≥ 5
138+
end
139+
122140
function lower_load_no_optranslation!(
123141
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}
124142
)
125-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, suffix = td
126-
loopdeps = loopdependencies(op)
127-
# @assert isvectorized(op)
128-
opu₁, opu₂ = isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, ls)
129-
u = ifelse(opu₁, u₁, 1)
130-
mvar = Symbol(variable_name(op, Core.ifelse(opu₂, suffix,-1)), '_', u)
131-
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls))
132-
if all(op.ref.loopedindex) && !rejectcurly(op)
133-
inds = unrolledindex(op, td, mask, inds_calc_by_ptr_offset, ls)
134-
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
135-
add_memory_mask!(loadexpr, op, td, mask, ls)
136-
push!(loadexpr.args, falseexpr, rs) # unaligned load
137-
push!(q.args, Expr(:(=), mvar, loadexpr))
138-
elseif (u₁ > 1) & opu₁
139-
t = Expr(:tuple)
140-
sptrsym = sptr!(q, op)
141-
for u ∈ 1:u₁
142-
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, u-1, ls)
143-
loadexpr = Expr(:call, lv(:_vload), sptrsym, inds)
144-
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
145-
add_memory_mask!(loadexpr, op, td, domask, ls)
146-
push!(loadexpr.args, falseexpr, rs)
147-
push!(t.args, loadexpr)
148-
# push!(q.args, Expr(:(=), mvar, loadexpr))
149-
end
150-
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
151-
else
152-
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, 0, ls)
153-
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
154-
add_memory_mask!(loadexpr, op, td, mask, ls)
155-
push!(loadexpr.args, falseexpr, rs)
156-
push!(q.args, Expr(:(=), mvar, loadexpr))
157-
end
158-
if isvectorized(op)
159-
prefetchind = prefetchisagoodidea(ls, op, td)
160-
iszero(prefetchind) || add_prefetches!(q, ls, op, td, prefetchind)
161-
elseif any(isvectorized, children(op))
162-
pushbroadcast!(q, mvar)
143+
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, suffix = td
144+
loopdeps = loopdependencies(op)
145+
# @assert isvectorized(op)
146+
opu₁, opu₂ = isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, ls)
147+
u = ifelse(opu₁, u₁, 1)
148+
mvar = Symbol(variable_name(op, Core.ifelse(opu₂, suffix,-1)), '_', u)
149+
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls))
150+
if (all(op.ref.loopedindex) && !rejectcurly(op)) && vectorization_profitable(op)
151+
inds = unrolledindex(op, td, mask, inds_calc_by_ptr_offset, ls)
152+
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
153+
add_memory_mask!(loadexpr, op, td, mask, ls)
154+
push!(loadexpr.args, falseexpr, rs) # unaligned load
155+
push!(q.args, Expr(:(=), mvar, loadexpr))
156+
elseif (u₁ > 1) & opu₁
157+
t = Expr(:tuple)
158+
sptrsym = sptr!(q, op)
159+
for u ∈ 1:u₁
160+
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, u-1, ls)
161+
loadexpr = Expr(:call, lv(:_vload), sptrsym, inds)
162+
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
163+
add_memory_mask!(loadexpr, op, td, domask, ls)
164+
push!(loadexpr.args, falseexpr, rs)
165+
push!(t.args, loadexpr)
166+
# push!(q.args, Expr(:(=), mvar, loadexpr))
163167
end
164-
nothing
168+
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
169+
else
170+
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, 0, ls)
171+
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
172+
add_memory_mask!(loadexpr, op, td, mask, ls)
173+
push!(loadexpr.args, falseexpr, rs)
174+
push!(q.args, Expr(:(=), mvar, loadexpr))
175+
end
176+
if isvectorized(op)
177+
prefetchind = prefetchisagoodidea(ls, op, td)
178+
iszero(prefetchind) || add_prefetches!(q, ls, op, td, prefetchind)
179+
elseif any(isvectorized, children(op))
180+
pushbroadcast!(q, mvar)
181+
end
182+
nothing
165183
end
166184
function indisvectorized(ls::LoopSet, ind::Symbol)
167-
for op ∈ operations(ls)
168-
((op.variable === ind) && isvectorized(op)) && return true
169-
end
170-
false
185+
for op ∈ operations(ls)
186+
((op.variable === ind) && isvectorized(op)) && return true
187+
end
188+
false
171189
end
172190
@inline firstunroll(vu::VecUnroll) = getfield(getfield(vu,:data),1,false)
173191
@inline firstunroll(x) = x
@@ -311,6 +329,7 @@ function _lower_load!(
311329
omop = offsetloadcollection(ls)
312330
@unpack opids, opidcollectionmap, batchedcollections, batchedcollectionmap = omop
313331
batchid, opind = batchedcollectionmap[identifier(op)]
332+
@show batchid, opind
314333
for (bid, oid) ∈ batchedcollectionmap # this relies on `for op ∈ ops` in codegen/operation_evaluation_order.jl
315334
if bid == batchid
316335
if oid == opind
@@ -335,38 +354,38 @@ function additive_vectorized_loopinductvar_only(op::Operation)
335354
end
336355
# Checks if we cannot use `Unroll`
337356
function rejectcurly(ls::LoopSet, op::Operation, td::UnrollArgs)
338-
@unpack u₁loopsym, vloopsym = td
339-
rejectcurly(ls, op, u₁loopsym, vloopsym)
357+
@unpack u₁loopsym, vloopsym = td
358+
rejectcurly(ls, op, u₁loopsym, vloopsym)
340359
end
341360
function rejectcurly(ls::LoopSet, op::Operation, u₁loopsym::Symbol, vloopsym::Symbol)
342-
indices = getindicesonly(op)
343-
li = op.ref.loopedindex
344-
AV = AU = false
345-
for (n,ind) ∈ enumerate(indices)
346-
# @show AU, op, n, ind, vloopsym, u₁loopsym
347-
if li[n]
348-
if ind === vloopsym
349-
AV && return true
350-
AV = true
351-
end
352-
if ind === u₁loopsym
353-
AU && return true
354-
AU = true
355-
end
356-
else
357-
opp = findop(parents(op), ind)
358-
# @show opp
359-
if isvectorized(opp)
360-
AV && return true
361-
AV = true
362-
end
363-
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp))
364-
AU && return true
365-
AU = true
366-
end
367-
end
361+
indices = getindicesonly(op)
362+
li = op.ref.loopedindex
363+
AV = AU = false
364+
for (n,ind) ∈ enumerate(indices)
365+
# @show AU, op, n, ind, vloopsym, u₁loopsym
366+
if li[n]
367+
if ind === vloopsym
368+
AV && return true
369+
AV = true
370+
end
371+
if ind === u₁loopsym
372+
AU && return true
373+
AU = true
374+
end
375+
else
376+
opp = findop(parents(op), ind)
377+
# @show opp
378+
if isvectorized(opp)
379+
AV && return true
380+
AV = true
381+
end
382+
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp))
383+
AU && return true
384+
AU = true
385+
end
368386
end
369-
false
387+
end
388+
false
370389
end
371390
function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true})
372391
strd = step(vloop)

0 commit comments

Comments
 (0)