Skip to content

Commit f730695

Browse files
committed
Major progress on overhauling LoopVectorization's lowering of indexing. Still need to finnish handling of loop remainders (this branch will currently segfault if not evenly divisble by the unrolling factors).
1 parent f926fdd commit f730695

12 files changed

+277
-182
lines changed

src/LoopVectorization.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ module LoopVectorization
33
using VectorizationBase, SIMDPirates, SLEEFPirates, UnPack, OffsetArrays
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valmulsub, valadd, valsub, _MM,
6-
maybestaticlength, maybestaticsize, staticm1, staticp1, subsetview, vzero, stridedpointer_for_broadcast,
7-
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
6+
maybestaticlength, maybestaticsize, staticm1, staticp1, staticmul, subsetview, vzero, stridedpointer_for_broadcast,
7+
Static, Zero, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
88
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,
99
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct,
1010
maybestaticfirst, maybestaticlast, scalar_less, scalar_greater, noalias!, gesp, gepbyte
@@ -47,9 +47,10 @@ include("add_compute.jl")
4747
include("add_constants.jl")
4848
include("add_ifelse.jl")
4949
include("determinestrategy.jl")
50+
include("loopstartstopmanager.jl")
5051
include("lower_compute.jl")
5152
include("lower_constant.jl")
52-
include("zero.jl")
53+
# include("zero.jl")
5354
include("lower_memory_common.jl")
5455
include("lower_load.jl")
5556
include("lower_store.jl")

src/add_constants.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementby
3636
temp = gensym(:intermediateconstref)
3737
vloadcall = Expr(:call, lv(:vload), mpref.mref.ptr)
3838
if length(getindices(op)) > 0
39-
push!(vloadcall.args, mem_offset(op, UnrollArgs(0, Symbol(""), Symbol(""), Symbol(""), nothing), false, false))
39+
push!(vloadcall.args, mem_offset(op, UnrollArgs(0, Symbol(""), Symbol(""), Symbol(""), 0, nothing), false, false))
4040
end
4141
pushpreamble!(ls, Expr(:(=), temp, vloadcall))
4242
pushpreamble!(ls, op, temp)

src/determinestrategy.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ end
454454
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
455455
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
456456
(vectorized == u₁loopsym || vectorized == u₂loopsym) && return false, false
457-
(u₁loopsym loopdependencies(op) && u₂loopsym loopdependencies(op)) || return false, false
457+
(isu₁unrolled(op) && isu₂unrolled(op)) || return false, false
458458
istranslation = false
459459
inds = getindices(op); li = op.ref.loopedindex
460460
translationplus = false
@@ -597,6 +597,7 @@ function evaluate_cost_tile(
597597
N = length(order)
598598
@assert N 2 "Cannot tile merely $N loops!"
599599
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
600+
cacheunrolled!(ls, u₁loopsym, u₂loopsym, vectorized)
600601
# u₂loopsym = order[1]
601602
# u₁loopsym = order[2]
602603
ops = operations(ls)
@@ -647,8 +648,8 @@ function evaluate_cost_tile(
647648
rd = reduceddependencies(op)
648649
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return 0,0,Inf,false
649650
included_vars[id] = true
650-
depends_on_u₁ = u₁loopsym loopdependencies(op)
651-
depends_on_u₂ = u₂loopsym loopdependencies(op)
651+
depends_on_u₁ = isu₁unrolled(op)
652+
depends_on_u₂ = isu₂unrolled(op)
652653
# cost is reduced by unrolling u₁ if it is interior to u₁loop (true if either u₁reached, or if depends on u₂ [or u₁]) and doesn't depend on u₁
653654
reduced_by_unrolling[1,id] = (u₁reached | depends_on_u₂) & !depends_on_u₁
654655
reduced_by_unrolling[2,id] = (u₂reached | depends_on_u₁) & !depends_on_u₂

src/graphs.jl

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@ struct UnrollArgs{T <: Union{Nothing,Int}}
99
u₁loopsym::Symbol
1010
u₂loopsym::Symbol
1111
vectorized::Symbol
12+
u₂max::Int
1213
suffix::T
1314
end
14-
function UnrollArgs(U::Int, unrollsyms::UnrollSymbols, suffix)
15+
function UnrollArgs(u₁::Int, unrollsyms::UnrollSymbols, u₂max::Int, suffix)
1516
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
16-
UnrollArgs(U, u₁loopsym, u₂loopsym, vectorized, suffix)
17+
UnrollArgs(u₁, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix)
1718
end
1819
function UnrollArgs(ua::UnrollArgs, u::Int)
19-
@unpack u₁loopsym, u₂loopsym, vectorized, suffix = ua
20-
UnrollArgs(u, u₁loopsym, u₂loopsym, vectorized, suffix)
20+
@unpack u₁loopsym, u₂loopsym, vectorized, u₂max, suffix = ua
21+
UnrollArgs(u, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix)
2122
end
2223
# UnrollSymbols(ua::UnrollArgs) = UnrollSymbols(ua.u₁loopsym, ua.u₂loopsym, ua.vectorized)
2324

@@ -69,14 +70,8 @@ function Loop(itersymbol::Symbol, start::Union{Int,Symbol}, stop::Union{Int,Symb
6970
end
7071
Base.length(loop::Loop) = 1 + loop.stophint - loop.starthint
7172
isstaticloop(loop::Loop) = loop.startexact & loop.stopexact
72-
function startloop(loop::Loop, isvectorized, itersymbol)
73+
function startloop(loop::Loop, itersymbol)
7374
startexact = loop.startexact
74-
# if isvectorized
75-
# if startexact
76-
# Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.starthint))
77-
# else
78-
# Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.startsym))
79-
# end
8075
if startexact
8176
Expr(:(=), itersymbol, loop.starthint - 1)
8277
else
@@ -97,31 +92,44 @@ addexpr(ex::Number, incr::Number) = ex + incr
9792
subexpr(ex, incr) = Expr(:call, lv(:vsub), ex, incr)
9893
subexpr(ex::Number, incr::Number) = ex - incr
9994
subexpr(ex, incr::Number) = addexpr(ex, -incr)
100-
function vec_looprange(loop::Loop, UF::Int, mangledname::Symbol)
95+
96+
staticmulincr(ptr, incr) = Expr(:call, lv(:staticmul), Expr(:call, :eltype, ptr), incr)
97+
callpointer(sym) = Expr(:call, :pointer, sym)
98+
function vec_looprange(loopmax, UF::Int, mangledname::Symbol, ptrcomp::Bool)
10199
incr = if isone(UF)
102100
Expr(:call, lv(:valsub), VECTORWIDTHSYMBOL, 1)
103101
else
104102
Expr(:call, lv(:valmulsub), VECTORWIDTHSYMBOL, UF, 1)
105103
end
106-
if loop.stopexact # split for type stability
107-
Expr(:call, :<, mangledname, subexpr(loop.stophint, incr))
104+
incr = ptrcomp ? staticmulincr(mangledname, incr) : incr
105+
compexpr = subexpr(loopmax, incr)
106+
if ptrcomp
107+
Expr(:call, :<, callpointer(mangledname), compexpr)
108108
else
109-
Expr(:call, :<, mangledname, subexpr(loop.stopsym, incr))
109+
Expr(:call, :<, mangledname, compexpr)
110110
end
111111
end
112112

113-
function looprange(stopcon, incr::Int, mangledname::Symbol)
113+
function looprange(stopcon, incr::Int, mangledname::Symbol, ptrcomp::Bool)
114114
incr = 1 - incr
115115
if iszero(incr)
116-
Expr(:call, :<, mangledname, stopcon)
117-
elseif isone(incr)
118-
Expr(:call, :, mangledname, stopcon)
116+
if ptrcomp
117+
Expr(:call, :<, callpointer(mangledname), stopcon)
118+
else
119+
Expr(:call, :<, mangledname, stopcon)
120+
end
121+
elseif ptrcomp
122+
Expr(:call, :<, callpointer(mangledname), addexpr(stopcon, staticmulincr(mangledname, incr)))
119123
else
120-
Expr(:call, :<, mangledname, addexpr(stopcon, incr))
124+
if isone(incr)
125+
Expr(:call, :, mangledname, stopcon)
126+
else
127+
Expr(:call, :<, mangledname, addexpr(stopcon, incr))
128+
end
121129
end
122130
end
123131
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
124-
loop.stopexact ? looprange(loop.stophint, incr, mangledname) : looprange(loop.stopsym, incr, mangledname)
132+
loop.stopexact ? looprange(loop.stophint, incr, mangledname, false) : looprange(loop.stopsym, incr, mangledname, false)
125133
end
126134
function terminatecondition(
127135
loop::Loop, us::UnrollSpecification, n::Int, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
@@ -130,22 +138,36 @@ function terminatecondition(
130138
looprange(loop, UF, mangledname)
131139
elseif inclmask
132140
looprange(loop, 1, mangledname)
141+
elseif loop.stopexact
142+
vec_looprange(loop.stophint, UF, mangledname, false) # may not be u₂loop
133143
else
134-
vec_looprange(loop, UF, mangledname) # may not be u₂loop
144+
vec_looprange(loop.stopsym, UF, mangledname, false) # may not be u₂loop
135145
end
136146
end
137147
function incrementloopcounter(us::UnrollSpecification, n::Int, mangledname::Symbol, UF::Int = unrollfactor(us, n))
138148
if isvectorized(us, n)
139149
if UF == 1
140150
Expr(:(=), mangledname, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, mangledname))
141-
# Expr(:(=), mangledname, Expr(:macrocall, Symbol("@show"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, mangledname)))
142151
else
143152
Expr(:(=), mangledname, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, UF, mangledname))
144153
end
145154
else
146155
Expr(:(=), mangledname, Expr(:call, lv(:vadd), mangledname, UF))
147156
end
148157
end
158+
function incrementloopcounter!(q, us::UnrollSpecification, n::Int, UF::Int = unrollfactor(us, n))
159+
if isvectorized(us, n)
160+
if UF == 1
161+
push!(q.args, Expr(:call, lv(:unwrap), VECTORWIDTHSYMBOL))
162+
else
163+
push!(q.args, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UF))
164+
end
165+
elseif isone(UF)
166+
push!(q.args, Expr(:call, Expr(:curly, lv(:Static), UF)))
167+
else
168+
push!(q.args, UF)
169+
end
170+
end
149171

150172
# load/compute/store × isunrolled × istiled × pre/post loop × Loop number
151173
struct LoopOrder <: AbstractArray{Vector{Operation},5}
@@ -294,19 +316,22 @@ function LoopSet(mod::Symbol)
294316
)
295317
end
296318

319+
cacheunrolled!(ls::LoopSet, u₁loop, u₂loop, vectorized) = foreach(op -> setunrolled!(op, u₁loop, u₂loop, vectorized), operations(ls))
320+
297321
num_loops(ls::LoopSet) = length(ls.loops)
298322
function oporder(ls::LoopSet)
299323
N = length(ls.loop_order.loopnames)
300324
reshape(ls.loop_order.oporder, (2,2,2,N))
301325
end
302326
names(ls::LoopSet) = ls.loop_order.loopnames
327+
reversenames(ls::LoopSet) = ls.loop_order.bestorder
303328
function getloopid(ls::LoopSet, s::Symbol)::Int
304329
for (loopnum,sym) enumerate(ls.loopsymbols)
305330
s === sym && return loopnum
306331
end
307332
end
308333
getloop(ls::LoopSet, s::Symbol) = ls.loops[getloopid(ls, s)]
309-
getloop(ls::LoopSet, i::Integer) = ls.loops[i]
334+
# getloop(ls::LoopSet, i::Integer) = ls.loops[i]
310335
getloopsym(ls::LoopSet, i::Integer) = ls.loopsymbols[i]
311336
Base.length(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
312337

src/lower_load.jl

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
function lower_load_scalar!(
2-
q::Expr, op::Operation, ua::UnrollArgs, umin::Int = 0
2+
q::Expr, op::Operation, ua::UnrollArgs, umin::Int, inds_calc_by_ptr_offset::Vector{Bool}
33
)
44
loopdeps = loopdependencies(op)
55
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = ua
66
@assert vectorized loopdeps
77
# mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loop, u₂loop, suffix)
88
mvar = variable_name(op, suffix)
9-
opu₁ = u₁loopsym loopdeps
10-
unrolled = (opu₁ || u₂loopsym loopdeps)
11-
ptr = unrolled ? offset_refname(op, ua) : refname(op)
9+
opu₁ = isu₁unrolled(op)
10+
ptr = refname(op)
1211
U = opu₁ ? u₁ : 1
1312
if instruction(op).instr !== :conditionalload
1413
for u umin:U-1
1514
varname = varassignname(mvar, u, opu₁)
1615
td = UnrollArgs(ua, u)
17-
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td, unrolled))))
16+
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td, inds_calc_by_ptr_offset))))
1817
end
1918
else
2019
opu₂ = !isnothing(suffix) && u₂loopsym loopdeps
@@ -24,20 +23,20 @@ function lower_load_scalar!(
2423
condsym = varassignname(condvar, u, condu₁)
2524
varname = varassignname(mvar, u, u₁loopsym loopdependencies(op))
2625
td = UnrollArgs(ua, u)
27-
load = Expr(:call, lv(:vload), ptr, mem_offset_u(op, td, unrolled))
26+
load = Expr(:call, lv(:vload), ptr, mem_offset_u(op, td, inds_calc_by_ptr_offset))
2827
cload = Expr(:if, condsym, load, Expr(:call, :zero, Expr(:call, :eltype, ptr)))
2928
push!(q.args, Expr(:(=), varname, cload))
3029
end
3130
end
3231
nothing
3332
end
3433
function pushvectorload!(
35-
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, vectorized::Symbol, mask, u₁unrolled::Bool, unrolled::Bool
34+
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, vectorized::Symbol, mask, u₁unrolled::Bool, inds_calc_by_ptr_offset::Vector{Bool}
3635
)
3736
@unpack u₁, u₁loopsym, u₂loopsym, suffix = td
38-
ptr = unrolled ? offset_refname(op, td) : refname(op)
37+
ptr = refname(op)
3938
vecnotunrolled = vectorized !== u₁loopsym
40-
name, mo = name_memoffset(var, op, td, u₁unrolled, unrolled)
39+
name, mo = name_memoffset(var, op, td, u₁unrolled, inds_calc_by_ptr_offset)
4140
instrcall = Expr(:call, lv(:vload), ptr, mo)
4241

4342
iscondstore = instruction(op).instr === :conditionalload
@@ -67,20 +66,27 @@ function pushvectorload!(
6766
end
6867
function prefetchisagoodidea(ls::LoopSet, op::Operation, td::UnrollArgs)
6968
# return false
70-
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = td
69+
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix = td
7170
vectorized loopdependencies(op) || return 0
7271
u₂loopsym === Symbol("##undefined##") && return 0
7372
dontskip = (64 ÷ VectorizationBase.REGISTER_SIZE) - 1
74-
(!isnothing(suffix) && (vectorized === u₂loopsym) && !iszero(suffix & dontskip)) && return 0
75-
innermostloopsym = last(ls.loop_order.bestorder)
73+
# u₂loopsym is vectorized
74+
# u₁vectorized = vectorized === u₁loopsym
75+
u₂vectorized = vectorized === u₂loopsym
76+
(!isnothing(suffix) && u₂vectorized && !iszero(suffix & dontskip)) && return 0
77+
innermostloopsym = first(names(ls))
7678
loopedindex = op.ref.loopedindex
7779
if length(loopedindex) > 1 && first(loopedindex)
7880
indices = getindices(op)
7981
if first(indices) === vectorized && last(indices) === innermostloopsym
82+
# We want at least 4 reuses per load
83+
uses = ifelse(isu₁unrolled(op), 1, u₁)
84+
uses = ifelse(isu₂unrolled(op), uses, uses * u₂max)
85+
uses < 4 && return 0
8086
innermostloopindv = findall(map(isequal(innermostloopsym), getindices(op)))
8187
isone(length(innermostloopindv)) || return 0
8288
innermostloopind = first(innermostloopindv)
83-
if prod(s -> length(getloop(ls, s)), @view(indices[1:innermostloopind-1])) 120 && length(getloop(ls, innermostloopind)) 120
89+
if prod(s -> length(getloop(ls, s)), @view(indices[1:innermostloopind-1])) 120 && length(getloop(ls, innermostloopsym)) 120
8490
if op.ref.ref.offsets[innermostloopind] < 120
8591
for opp operations(ls)
8692
iscompute(opp) && load_constrained(opp, u₁loopsym, u₂loopsym) && return 0
@@ -97,9 +103,9 @@ function lower_load_vectorized!(
97103
)
98104
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = td
99105
loopdeps = loopdependencies(op)
100-
@assert vectorized loopdeps
101-
opu₁ = u₁loopsym loopdeps
102-
unrolled = opu₁ || u₂loopsym loopdeps
106+
@assert isvectorized(op)
107+
opu₁ = isu₁unrolled(op)
108+
inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets(ls, op.ref)
103109
if opu₁
104110
umin = umin
105111
U = u₁
@@ -111,13 +117,13 @@ function lower_load_vectorized!(
111117
var = variable_name(op, suffix)
112118
for u umin:U-1
113119
td = UnrollArgs(td, u)
114-
pushvectorload!(q, op, var, td, U, vectorized, mask, opu₁, unrolled)
120+
pushvectorload!(q, op, var, td, U, vectorized, mask, opu₁, inds_calc_by_ptr_offset)
115121
end
116122
prefetchind = prefetchisagoodidea(ls, op, td)
117123
if !iszero(prefetchind)
118124
dontskip = (64 ÷ VectorizationBase.REGISTER_SIZE) - 1
119-
ptr = offset_refname(op, td)
120-
innermostloopsym = last(ls.loop_order.bestorder)
125+
ptr = refname(op)
126+
innermostloopsym = first(names(ls))
121127
us = ls.unrollspecification[]
122128
prefetch_multiplier = 4
123129
prefetch_distance = u₁loopsym === innermostloopsym ? us.u₁ : ( u₂loopsym === innermostloopsym ? us.u₂ : 1 )
@@ -128,7 +134,7 @@ function lower_load_vectorized!(
128134
# for u ∈ umin:min(umin,U-1)
129135
(u₁loopsym === vectorized && !iszero(u & dontskip)) && continue
130136
offsets[prefetchind] = inner_offset + prefetch_distance
131-
mo = mem_offset_u(op, UnrollArgs(td, u), true)
137+
mo = mem_offset_u(op, UnrollArgs(td, u), inds_calc_by_ptr_offset)
132138
instrcall = Expr(:call, lv(:prefetch0), ptr, mo)
133139
push!(q.args, instrcall)
134140
end
@@ -158,7 +164,7 @@ function lower_load!(
158164
varnew = variable_name(op, suffix)
159165
varold = variable_name(operations(ls)[id], suffix + mno)
160166
opold = operations(ls)[id]
161-
if u₁loopsym loopdependencies(op)
167+
if isu₁unrolled(op)
162168
for u 0:u₁-1
163169
push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u)))
164170
end
@@ -173,12 +179,11 @@ function lower_load!(
173179
umin = 0
174180
end
175181
else
176-
maybegesp_call!(q, op, td)
177182
umin = 0
178183
end
179-
if vectorized loopdependencies(op)
184+
if isvectorized(op)
180185
lower_load_vectorized!(q, ls, op, td, mask, umin)
181186
else
182-
lower_load_scalar!(q, op, td, umin)
187+
lower_load_scalar!(q, op, td, umin, indices_calculated_by_pointer_offsets(ls, op.ref))
183188
end
184189
end

0 commit comments

Comments
 (0)