Skip to content

Commit d3701d5

Browse files
committed
Add constant offset indices based dependency check
1 parent 37527ad commit d3701d5

File tree

9 files changed

+506
-191
lines changed

9 files changed

+506
-191
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.24"
4+
version = "0.12.25"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/codegen/lower_load.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -331,18 +331,20 @@ end
331331
function _lower_load!(
332332
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool} = indices_calculated_by_pointer_offsets(ls, op.ref)
333333
)
334-
if rejectinterleave(op)
335-
lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
336-
else
337-
omop = offsetloadcollection(ls)
338-
batchid, opind = omop.batchedcollectionmap[identifier(op)]
339-
if opind == 1
340-
collectionid, copind = omop.opidcollectionmap[identifier(op)]
341-
opidmap = offsetloadcollection(ls).opids[collectionid]
342-
idsformap = omop.batchedcollections[batchid]
343-
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
344-
end
334+
if rejectinterleave(op)
335+
lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
336+
else
337+
omop = offsetloadcollection(ls)
338+
@unpack opids, opidcollectionmap, batchedcollections, batchedcollectionmap = omop
339+
batchid, opind = batchedcollectionmap[identifier(op)]
340+
if opind == 1
341+
collectionid, copind = opidcollectionmap[identifier(op)]
342+
opidmap = opids[collectionid]
343+
idsformap = batchedcollections[batchid]
344+
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
345345
end
346+
end
347+
return nothing
346348
end
347349
function additive_vectorized_loopinductvar_only(op::Operation)
348350
isvectorized(op) || return true
@@ -388,18 +390,21 @@ function rejectcurly(ls::LoopSet, op::Operation, u₁loopsym::Symbol, vloopsym::
388390
false
389391
end
390392
function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true})
391-
vloopsym = vloop.itersymbol; strd = step(vloop)
392-
isknown(strd) || return true
393-
# TODO: reject if there is a vectorized !loopedindex index
394-
for (li,ind) zip(op.ref.loopedindex,getindicesonly(op))
395-
li && continue
396-
for indop operations(ls)
397-
if (name(indop) === ind) && isvectorized(indop)
398-
additive_vectorized_loopinductvar_only(indop) || return true # so that it is `MM`
399-
end
400-
end
393+
strd = step(vloop)
394+
isknown(strd) || return true
395+
# TODO: reject if there is a vectorized !loopedindex index
396+
indices = getindicesonly(op); li = op.ref.loopedindex
397+
for i eachindex(li)
398+
li[i] && continue
399+
ind = indices[i]
400+
for indop operations(ls)
401+
if (name(indop) === ind) && isvectorized(indop)
402+
additive_vectorized_loopinductvar_only(indop) || return true # so that it is `MM`
403+
end
401404
end
402-
(first(getindices(op)) === vloopsym) && (length(idsformap) first(getstrides(op)) * gethint(strd))
405+
end
406+
vloopsym = vloop.itersymbol;
407+
(first(getindices(op)) === vloopsym) && (length(idsformap) first(getstrides(op)) * gethint(strd))
403408
end
404409
# function lower_load_collection_manual_u₁unroll!(
405410
# q::Expr, ls::LoopSet, opidmap::Vector{Int},

src/codegen/lower_memory_common.jl

Lines changed: 71 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -185,80 +185,83 @@ end
185185

186186
# interleave: `0` means `false`, positive means literal, negative means multiplier
187187
function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, mask::Bool, interleave::Int=0)
188-
u₁loopsym = u₁loop.itersymbol
189-
vloopsym = vloop.itersymbol
190-
indices = getindicesonly(op)
191-
vstep = step(vloop)
192-
li = op.ref.loopedindex
193-
# @assert all(loopedindex)
194-
# @unpack u₁, u₁loopsym, vloopsym = td
195-
AV = AU = -1
196-
for (n,ind) enumerate(indices)
197-
# @show AU, op, n, ind, vloopsym, u₁loopsym
198-
if li[n]
199-
if ind === vloopsym
200-
@assert AV == -1 # FIXME: these asserts should be replaced with checks that prevent using `unrolled_curly` in these cases (also to be reflected in cost modeling, to avoid those)
201-
AV = n
202-
end
203-
if ind === u₁loopsym
204-
@assert AU == -1
205-
AU = n
206-
end
207-
else
208-
opp = findop(parents(op), ind)
209-
# @show opp
210-
if isvectorized(opp)
211-
@assert AV == -1
212-
AV = n
213-
end
214-
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX loopdependencies(opp)) : (isu₁unrolled(opp) || (ind === u₁loopsym))
215-
@assert AU == -1
216-
AU = n
217-
end
218-
end
219-
end
220-
AU == -1 && throw(LoopError("Failed to find $(u₁loopsym) in args of $(repr(op))."))
221-
vecnotunrolled = AU != AV
222-
conditional_memory_op = isconditionalmemop(op)
223-
if mask || conditional_memory_op
224-
M = one(UInt)
225-
# `isu₁unrolled(last(parents(op)))` === is condop unrolled?
226-
# isu₁unrolled(last(parents(op)))
227-
if vecnotunrolled || conditional_memory_op || (interleave > 0) # mask all
228-
M = (M << u₁) - M
229-
else # mask last
230-
M <<= (u₁ - 1)
188+
u₁loopsym = u₁loop.itersymbol
189+
vloopsym = vloop.itersymbol
190+
indices = getindicesonly(op)
191+
vstep = step(vloop)
192+
li = op.ref.loopedindex
193+
# @assert all(loopedindex)
194+
# @unpack u₁, u₁loopsym, vloopsym = td
195+
AV = AU = -1
196+
for (n,ind) enumerate(indices)
197+
# @show AU, op, n, ind, vloopsym, u₁loopsym
198+
if li[n]
199+
if ind === vloopsym
200+
@assert AV == -1 # FIXME: these asserts should be replaced with checks that prevent using `unrolled_curly` in these cases (also to be reflected in cost modeling, to avoid those)
201+
AV = n
202+
end
203+
if ind === u₁loopsym
204+
if AU -1
205+
u₁loopsym === CONSTANTZEROINDEX && continue
206+
throw(ArgumentError("Two of the same index $ind?"))
231207
end
208+
AU = n
209+
end
232210
else
233-
M = zero(UInt)
211+
opp = findop(parents(op), ind)
212+
# @show opp
213+
if isvectorized(opp)
214+
@assert AV == -1
215+
AV = n
216+
end
217+
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX loopdependencies(opp)) : (isu₁unrolled(opp) || (ind === u₁loopsym))
218+
@assert AU == -1
219+
AU = n
220+
end
221+
end
222+
end
223+
AU == -1 && throw(LoopError("Failed to find $(u₁loopsym) in args of $(repr(op))."))
224+
vecnotunrolled = AU != AV
225+
conditional_memory_op = isconditionalmemop(op)
226+
if mask || conditional_memory_op
227+
M = one(UInt)
228+
# `isu₁unrolled(last(parents(op)))` === is condop unrolled?
229+
# isu₁unrolled(last(parents(op)))
230+
if vecnotunrolled || conditional_memory_op || (interleave > 0) # mask all
231+
M = (M << u₁) - M
232+
else # mask last
233+
M <<= (u₁ - 1)
234234
end
235-
@assert isknown(step(u₁loop)) "Unrolled loops must have known steps to use `Unroll` type; this is a bug, shouldn't have reached here"
236-
if AV > 0
237-
@assert isknown(step(vloop)) "Vectorized loops must have known steps to use `Unroll` type; this is a bug, shouldn't have reached here."
238-
X = convert(Int, getstrides(op)[AV])
239-
X *= gethint(step(vloop))
240-
intvecsym = :(Int($VECTORWIDTHSYMBOL))
241-
if interleave > 0
242-
Expr(:curly, lv(:Unroll), AU, interleave, u₁, AV, intvecsym, M, X)
243-
elseif interleave < 0
244-
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, -interleave))))
245-
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
235+
else
236+
M = zero(UInt)
237+
end
238+
@assert isknown(step(u₁loop)) "Unrolled loops must have known steps to use `Unroll` type; this is a bug, shouldn't have reached here"
239+
if AV > 0
240+
@assert isknown(step(vloop)) "Vectorized loops must have known steps to use `Unroll` type; this is a bug, shouldn't have reached here."
241+
X = convert(Int, getstrides(op)[AV])
242+
X *= gethint(step(vloop))
243+
intvecsym = :(Int($VECTORWIDTHSYMBOL))
244+
if interleave > 0
245+
Expr(:curly, lv(:Unroll), AU, interleave, u₁, AV, intvecsym, M, X)
246+
elseif interleave < 0
247+
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, -interleave))))
248+
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
249+
else
250+
if vecnotunrolled
251+
# Expr(:call, Expr(:curly, lv(:Unroll), AU, 1, u₁, AV, intvecsym, M, 1), ind)
252+
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, AV, intvecsym, M, X)
253+
else
254+
if isone(X)
255+
Expr(:curly, lv(:Unroll), AU, intvecsym, u₁, AV, intvecsym, M, X)
246256
else
247-
if vecnotunrolled
248-
# Expr(:call, Expr(:curly, lv(:Unroll), AU, 1, u₁, AV, intvecsym, M, 1), ind)
249-
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, AV, intvecsym, M, X)
250-
else
251-
if isone(X)
252-
Expr(:curly, lv(:Unroll), AU, intvecsym, u₁, AV, intvecsym, M, X)
253-
else
254-
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, X))))
255-
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
256-
end
257-
end
257+
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, X))))
258+
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
258259
end
259-
else
260-
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, 0, 1, M, 1)
260+
end
261261
end
262+
else
263+
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, 0, 1, M, 1)
264+
end
262265
end
263266
function unrolledindex(op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}, ls::LoopSet)
264267
@unpack u₁, u₁loopsym, u₁loop, vloop = td

src/codegen/lowering.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ function lower!(
2626
end
2727
end
2828
end
29-
29+
function isu₂invalidstorereorder(ls::LoopSet, us::UnrollSpecification)
30+
us.u₂ == -1 ? false : ls.validreorder[ls.loopordermap[us.u₂loopnum]] 0x03
31+
end
3032
function lower_block(
3133
ls::LoopSet, us::UnrollSpecification, n::Int, mask::Bool, UF::Int
3234
)
@@ -40,6 +42,7 @@ function lower_block(
4042
u₁ = n == u₁loopnum ? UF : u₁
4143
dontmaskfirsttiles = mask && vloopnum == u₂loopnum
4244
blockq = Expr(:block)
45+
cannot_reorder_u₂ = isu₂invalidstorereorder(ls, us)
4346
for prepost 1:2
4447
# !u₁ && !u₂
4548
lower!(blockq, ops[1,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, true)
@@ -56,22 +59,30 @@ function lower_block(
5659
lower_tiled_store!(blockq, opsv1, opsv2, ls, unrollsyms, u₁, u₂, mask)
5760
else
5861
for store (false,true)
62+
if cannot_reorder_u₂
63+
nstores = 0# break
64+
lowernonstore = lowerstore = true
65+
else
66+
lowerstore = store; lowernonstore = !store
67+
end
5968
for t 0:u₂-1
6069
# !u₁ && u₂
61-
lower!(blockq, opsv1, ls, unrollsyms, u₁, u₂, t, mask & !(dontmaskfirsttiles & (t < u₂ - 1)), !store, store)
70+
lower!(blockq, opsv1, ls, unrollsyms, u₁, u₂, t, mask & !(dontmaskfirsttiles & (t < u₂ - 1)), lowernonstore, lowerstore)
6271
if iszero(t) && !store # u₁ && !u₂
6372
# for u ∈ 0:u₁-1
64-
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, true)
73+
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, lowernonstore, lowerstore)
6574
# end
6675
end
6776
# u₁ && u₂
6877
# for u ∈ 0:u₁-1
69-
lower!(blockq, opsv2, ls, unrollsyms, u₁, u₂, t, mask & !(dontmaskfirsttiles & (t < u₂ - 1)), !store, store)
78+
lower!(blockq, opsv2, ls, unrollsyms, u₁, u₂, t, mask & !(dontmaskfirsttiles & (t < u₂ - 1)), lowernonstore, lowerstore)
7079
# end
7180
end
7281
nstores == 0 && break
7382
end
7483
end
84+
elseif cannot_reorder_u₂
85+
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, true)
7586
else
7687
# for u ∈ 0:u₁-1 # u₁ && !u₂
7788
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, false)

0 commit comments

Comments
 (0)