Skip to content

Commit 3bbfcfb

Browse files
committed
LoopVectorization tests pass.
1 parent 5f8f32b commit 3bbfcfb

13 files changed

+133
-105
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module LoopVectorization
22

33
using VectorizationBase, SIMDPirates, SLEEFPirates, UnPack, OffsetArrays
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
5-
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM,
5+
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valmulsub, valadd, valsub, _MM,
66
maybestaticlength, maybestaticsize, staticm1, staticp1, subsetview, vzero, stridedpointer_for_broadcast,
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
88
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,

src/determinestrategy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ function determine_unroll_factor(
233233
load_recip_throughput,
234234
store_recip_throughput
235235
)
236-
roundpow2(max(1, round(Int, latency / (recip_throughput * num_reductions) ) ))
236+
min(8, roundpow2(max(1, round(Int, latency / (recip_throughput * num_reductions) ) )))
237237
end
238238

239239
function unroll_cost(X, u₁, u₂, u₁L, u₂L)

src/graphs.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,32 +78,51 @@ function startloop(loop::Loop, isvectorized, itersymbol)
7878
# Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.startsym))
7979
# end
8080
if startexact
81-
Expr(:(=), itersymbol, loop.starthint)
81+
Expr(:(=), itersymbol, loop.starthint - 1)
8282
else
83-
Expr(:(=), itersymbol, Expr(:call, lv(:unwrap), loop.startsym))
83+
Expr(:(=), itersymbol, Expr(:call, lv(:staticm1), Expr(:call, lv(:unwrap), loop.startsym)))
8484
end
8585
end
86+
addexpr(ex, incr) = Expr(:call, lv(:vadd), ex, incr)
87+
function addexpr(ex, incr::Number)
88+
if iszero(incr)
89+
incr
90+
elseif incr > 0
91+
Expr(:call, lv(:vadd), ex, incr)
92+
else
93+
Expr(:call, lv(:vsub), ex, -incr)
94+
end
95+
end
96+
addexpr(ex::Number, incr::Number) = ex + incr
97+
subexpr(ex, incr) = Expr(:call, lv(:vsub), ex, incr)
98+
subexpr(ex::Number, incr::Number) = ex - incr
99+
subexpr(ex, incr::Number) = addexpr(ex, -incr)
86100
function vec_looprange(loop::Loop, UF::Int, mangledname::Symbol)
87-
isunrolled = UF > 1
88-
incr = if isunrolled
89-
Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, UF, -2)
101+
incr = if isone(UF)
102+
Expr(:call, lv(:valsub), VECTORWIDTHSYMBOL, 1)
90103
else
91-
Expr(:call, lv(:valsub), VECTORWIDTHSYMBOL, 2)
104+
Expr(:call, lv(:valmulsub), VECTORWIDTHSYMBOL, UF, 1)
92105
end
93106
if loop.stopexact # split for type stability
94-
Expr(:call, :<, mangledname, Expr(:call, lv(:vsub), loop.stophint, incr))
107+
Expr(:call, :<, mangledname, subexpr(loop.stophint, incr))
95108
else
96-
Expr(:call, :<, mangledname, Expr(:call, lv(:vsub), loop.stopsym, incr))
109+
Expr(:call, :<, mangledname, subexpr(loop.stopsym, incr))
97110
end
98111
end
99-
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
100-
incr = 2 - incr
112+
113+
function looprange(stopcon, incr::Int, mangledname::Symbol)
114+
incr = 1 - incr
101115
if iszero(incr)
102-
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint : loop.stopsym)
116+
Expr(:call, :<, mangledname, stopcon)
117+
elseif isone(incr)
118+
Expr(:call, :≤, mangledname, stopcon)
103119
else
104-
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint + incr : Expr(:call, lv(:vadd), loop.stopsym, incr))
120+
Expr(:call, :<, mangledname, addexpr(stopcon, incr))
105121
end
106122
end
123+
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
124+
loop.stopexact ? looprange(loop.stophint, incr, mangledname) : looprange(loop.stopsym, incr, mangledname)
125+
end
107126
function terminatecondition(
108127
loop::Loop, us::UnrollSpecification, n::Int, mangledname::Symbol, inclmask::Bool, UF::Int = unrollfactor(us, n)
109128
)

src/lower_compute.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ end
7171
function add_loopvalue!(instrcall::Expr, loopval::Symbol, vectorized::Symbol, u::Int)
7272
if loopval === vectorized
7373
if isone(u)
74-
push!(instrcall.args, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, _MMind(loopval)))
74+
push!(instrcall.args, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, _MMind(Expr(:call, lv(:staticp1), loopval))))
7575
else
76-
push!(instrcall.args, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, u, _MMind(loopval)))
76+
push!(instrcall.args, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, u, _MMind(Expr(:call, lv(:staticp1), loopval))))
7777
end
7878
else
79-
push!(instrcall.args, Expr(:call, :+, loopval, u))
79+
push!(instrcall.args, Expr(:call, lv(:vadd), loopval, u + 1))
8080
end
8181
end
8282
function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u::Int)
@@ -86,15 +86,14 @@ function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u::Int)
8686
elseif !isnothing(suffix) && suffix > 0 && loopval === u₂loopsym
8787
add_loopvalue!(instrcall, loopval, vectorized, suffix)
8888
elseif loopval === vectorized
89-
push!(instrcall.args, _MMind(loopval))
89+
push!(instrcall.args, _MMind(Expr(:call, lv(:staticp1), loopval)))
9090
else
91-
push!(instrcall.args, loopval)
91+
push!(instrcall.args, Expr(:call, lv(:staticp1), loopval))
9292
end
9393
end
9494

9595
function lower_compute!(
9696
q::Expr, op::Operation, ua::UnrollArgs, mask::Union{Nothing,Symbol,Unsigned} = nothing,
97-
opunrolled = ua.u₁loopsym ∈ loopdependencies(op)
9897
)
9998
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = ua
10099
var = name(op)
@@ -159,7 +158,7 @@ function lower_compute!(
159158
# diffdeps = !any(opp -> isload(opp) && all(in(loopdependencies(opp)), loopdependencies(op)), parents(op)) # want to instcombine when parent load's deps are superset
160159
# @show suffix, !isnothing(suffix), isreduct, diffdeps
161160
# end
162-
if !isnothing(suffix) && isreduct
161+
if !isnothing(suffix) && isreduct# && (iszero(suffix) || (ls.unrollspecification[].u₂ - 1 == suffix))
163162
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
164163
instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
165164
# want to instcombine when parent load's deps are superset

src/lower_load.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ function pushvectorload!(
6363
push!(instrcall.args, mask)
6464
end
6565
push!(q.args, Expr(:(=), name, instrcall))
66+
# push!(q.args, :(@show $name))
6667
end
6768
function prefetchisagoodidea(ls::LoopSet, op::Operation, td::UnrollArgs)
6869
# return false

src/lower_memory_common.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ function symbolind(ind::Symbol, op::Operation, td::UnrollArgs)
1515
else
1616
mangledvar(parent)
1717
end
18-
u₁loopsym ∈ loopdependencies(parent) ? Symbol(pvar, u₁) : pvar
18+
ex = u₁loopsym ∈ loopdependencies(parent) ? Symbol(pvar, u₁) : pvar
19+
Expr(:call, lv(:staticm1), ex)
1920
end
2021

2122

2223
_MMind(ind) = Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, ind)
24+
_MMind(ind::Integer) = Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, convert(Int, ind))
2325
function addoffset!(ret::Expr, ex, offset::Integer, _mm::Bool = false)
2426
if iszero(offset)
2527
if _mm
@@ -147,7 +149,7 @@ function mem_offset_u(op::Operation, td::UnrollArgs, unrolled::Bool)
147149
# append_inds!(ret, indices, loopedindex)
148150
else
149151
for (n,ind) ∈ enumerate(indices)
150-
offset = offsets[n]
152+
offset = convert(Int, offsets[n])
151153
# if ind isa Int # impossible
152154
# push!(ret.args, ind + offset)
153155
# else
@@ -182,15 +184,6 @@ function mem_offset_u(op::Operation, td::UnrollArgs, unrolled::Bool)
182184
ret
183185
end
184186

185-
# function add_expr(q, incr)
186-
# if q.head === :call && q.args[2] === :+
187-
# qc = copy(q)
188-
# push!(qc.args, incr)
189-
# qc
190-
# else
191-
# Expr(:call, :+, q, incr)
192-
# end
193-
# end
194187
function varassignname(var::Symbol, u::Int, isunrolled::Bool)
195188
isunrolled ? Symbol(var, u) : var
196189
end

src/lower_store.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ end
6363

6464
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, U::Int)
6565
U == 1 && return nothing
66+
@assert U > 1 "U = $U somehow < 1"
6667
instr = Instruction(reduction_to_single_vector(instr))
6768
Uh2 = U
6869
iter = 0

src/lowering.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ function assume(ex)
188188
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:assume)), ex)
189189
end
190190
function expect(ex)
191-
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:expect)), ex)
191+
# Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:expect)), ex)
192+
ex
192193
end
193194
function loopiteratesatleastonce(loop::Loop, as::Bool = true)
194195
comp = if loop.startexact # requires !loop.stopexact
@@ -209,31 +210,29 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
209210
# if VERSION ≥ v"1.4" && !nisvectorized && !inclmask && isone(n) && !ls.loadelimination[] && (us.u₁ > 1) && (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && length(loop) > 7
210211
# return lower_llvm_unroll(ls, us, n, loop)
211212
# end
212-
213-
214213
sl = startloop(loop, nisvectorized, loopsym)
215214
tc = terminatecondition(loop, us, n, loopsym, inclmask, 1)
216215
body = lower_block(ls, us, n, inclmask, 1)
217216
q = if (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !isstaticloop(loop) && !inclmask# && !ls.loadelimination[]
218217
# Expr(:block, sl, assumeloopiteratesatleastonce(loop), Expr(:while, tc, body))
219218
if nisvectorized
220-
Expr(:block, sl, loopiteratesatleastonce(loop, true), Expr(:while, expect(tc), body))
219+
Expr(:block, loopiteratesatleastonce(loop, true), Expr(:while, expect(tc), body))
221220
else
222221
# Expr(:block, sl, assume(tc), Expr(:while, tc, body))
223222
push!(body.args, Expr(:||, expect(tc), Expr(:break)))
224223
# Expr(:block, sl, assume(tc), Expr(:while, true, body))
225-
Expr(:block, sl, Expr(:while, true, body))
224+
Expr(:block, Expr(:while, true, body))
226225
end
227226
else
228-
Expr(:block, sl, Expr(:while, tc, body))
227+
Expr(:block, Expr(:while, tc, body))
229228
end
230229

231230
if nisvectorized
232231
tc = terminatecondition(loop, us, n, loopsym, true, 1)
233232
body = lower_block(ls, us, n, true, 1)
234233
push!(q.args, Expr(:if, tc, body))
235234
end
236-
q
235+
Expr(:block, Expr(:let, sl, q))
237236
end
238237
function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool)
239238
UF = unrollfactor(us, n)
@@ -254,12 +253,12 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
254253

255254
remfirst = loopisstatic & !(unsigned(Ureduct) < unsigned(UF))
256255
if remfirst
257-
tc = Expr(:call, lv(:scalar_less), loopsym, loop.stophint + 1)
256+
tc = Expr(:call, :<, loopsym, loop.stophint)
258257
else
259258
tc = terminatecondition(loop, us, n, loopsym, inclmask, UF)
260259
end
261260
usorig = ls.unrollspecification[]
262-
tc = (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !isstaticloop(loop) && !inclmask && !ls.loadelimination[] ? expect(tc) : tc
261+
tc = (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !loopisstatic && !inclmask && !ls.loadelimination[] ? expect(tc) : tc
263262
body = lower_block(ls, us, n, inclmask, UF)
264263
q = Expr(:while, tc, body)
265264
remblock = init_remblock(loop, loopsym)
@@ -272,7 +271,7 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
272271
UF_cleanup = UF - Ureduct
273272
us_cleanup = nisunrolled ? UnrollSpecification(us, UF_cleanup, u₂) : UnrollSpecification(us, u₁, UF_cleanup)
274273
Expr(
275-
:block, sl,
274+
:block,
276275
add_upper_outer_reductions(ls, q, Ureduct, UF, loop, vectorized),
277276
Expr(
278277
:if, terminatecondition(loop, us, n, loopsym, inclmask, UF_cleanup),
@@ -283,9 +282,9 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
283282
elseif remfirst
284283
numiters = length(loop) ÷ UF
285284
if numiters > 2
286-
Expr( :block, sl, remblock, q )
285+
Expr( :block, remblock, q )
287286
else
288-
q = Expr(:block, sl, remblock)
287+
q = Expr(:block, remblock)
289288
for i ∈ 1:numiters
290289
push!(q.args, body)
291290
end
@@ -298,7 +297,7 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
298297
# else
299298
# Expr(:block, sl, q, remblock)
300299
# end
301-
Expr( :block, sl, q, remblock )
300+
Expr( :block, q, remblock )
302301
end
303302
UFt = if loopisstatic
304303
length(loop) % UF
@@ -308,17 +307,17 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
308307
while !iszero(UFt)
309308
comparison = if nisvectorized
310309
itercount = if loop.stopexact
311-
Expr(:call, lv(:vsub), loop.stophint, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UFt))
310+
Expr(:call, lv(:vsub), loop.stophint - 1, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UFt))
312311
else
313-
Expr(:call, lv(:vsub), loop.stopsym, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UFt))
312+
Expr(:call, lv(:vsub), loop.stopsym, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, UFt, 1))
314313
end
315-
Expr(:call, lv(:scalar_greater), loopsym, itercount)
314+
Expr(:call, :>, loopsym, itercount)
316315
elseif remfirst
317-
Expr(:call, lv(:scalar_less), loopsym, loop.starthint + UFt)
316+
Expr(:call, :<, loopsym, loop.starthint + UFt - 1)
318317
elseif loop.stopexact
319-
Expr(:call, lv(:scalar_greater), loopsym, loop.stophint - UFt)
318+
Expr(:call, :>, loopsym, loop.stophint - UFt - 1)
320319
else
321-
Expr(:call, lv(:scalar_greater), loopsym, Expr(:call, lv(:vsub), loop.stopsym, UFt))
320+
Expr(:call, :>, loopsym, Expr(:call, lv(:vsub), loop.stopsym, UFt + 1))
322321
end
323322
ust = nisunrolled ? UnrollSpecification(us, UFt, u₂) : UnrollSpecification(us, u₁, UFt)
324323
remblocknew = Expr(:elseif, comparison, lower_block(ls, ust, n, remmask, UFt))
@@ -330,7 +329,7 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
330329
UFt += 1
331330
end
332331
end
333-
q
332+
Expr(:block, Expr(:let, sl, q))
334333
end
335334

336335
function initialize_outer_reductions!(
@@ -432,9 +431,9 @@ function determine_width(
432431
end
433432
function init_remblock(unrolledloop::Loop, u₁loop::Symbol = unrolledloop.itersymbol)
434433
condition = if unrolledloop.stopexact
435-
Expr(:call, lv(:scalar_greater), u₁loop, unrolledloop.stophint)
434+
Expr(:call, :>, u₁loop, unrolledloop.stophint - 1)
436435
else
437-
Expr(:call, lv(:scalar_greater), u₁loop, unrolledloop.stopsym)
436+
Expr(:call, :≥, u₁loop, unrolledloop.stopsym)
438437
end
439438
Expr(:if, condition, nothing)
440439
end

src/reconstruct_loopset.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,22 @@ function pushvarg′!(ls::LoopSet, ar::ArrayReferenceMeta, i, name)
124124
reverse!(ar.loopedindex); reverse!(getindices(ar)) # reverse the listed indices here, and transpose it to make it column major
125125
pushpreamble!(ls, Expr(:(=), name, Expr(:call, lv(:transpose), extract_varg(i))))
126126
end
127+
function assume_strides!(ls, name, N)
128+
for n ∈ 1:N
129+
pushpreamble!(ls, assume( Expr(:call, :>, Expr(:ref, Expr(:(.), name, QuoteNode(:strides)), n), 0) ))
130+
end
131+
end
127132
function add_mref!(
128133
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}, name = vptr(ar)
129134
) where {T, N, S <: AbstractColumnMajorStridedPointer{T,N}}
130135
pushvarg!(ls, ar, i, name)
136+
assume_strides!(ls, name, N)
131137
end
132138
function add_mref!(
133139
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}, name = vptr(ar)
134140
) where {T, N, S <: AbstractRowMajorStridedPointer{T, N}}
135141
pushvarg′!(ls, ar, i, name)
142+
assume_strides!(ls, name, N)
136143
end
137144
function add_mref!(
138145
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{S}, name = vptr(ar)
@@ -150,6 +157,7 @@ function add_mref!(
150157
li[i] = lib[S1[i]]
151158
inds[i] = indsb[S1[i]]
152159
end
160+
assume_strides!(ls, name, length(S1))
153161
end
154162
function add_mref!(
155163
ls::LoopSet, ar::ArrayReferenceMeta, i::Int, ::Type{OffsetStridedPointer{T,N,P}}, name = vptr(ar)

src/vectorizationbase_extensions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
# if ndim(A::OffsetArray) ≥ 2, then eachindex(A) isa Base.OneTo, index starting at 1.
1212
# but multiple indexing is calculated using offsets, so we need a special type to express this.
1313
@inline function VectorizationBase.stridedpointer(A::OffsetArrays.OffsetArray)
14-
OffsetStridedPointer(stridedpointer(parent(A)), VectorizationBase.staticm1(A.offsets))
14+
OffsetStridedPointer(stridedpointer(parent(A)), A.offsets)
1515
end
1616

1717
@inline function VectorizationBase.stridedpointer(
@@ -20,7 +20,7 @@ end
2020
Boff = parent(B)
2121
OffsetStridedPointer(
2222
stridedpointer(parent(Boff)'),
23-
VectorizationBase.staticm1(Boff.offsets)
23+
Boff.offsets
2424
)
2525
end
2626
@inline function Base.transpose(A::OffsetStridedPointer)
@@ -37,7 +37,7 @@ end
3737
@inline VectorizationBase.offset(ptr::OffsetStridedPointer{<:Any,N}, ind::Tuple) where {N} = ntuple(n -> vsub(ind[n], ptr.offsets[n]), Val{N}())
3838
@inline Base.similar(p::OffsetStridedPointer, ptr::Ptr) = OffsetStridedPointer(similar(p.ptr, ptr), p.offsets)
3939
@inline Base.pointer(p::OffsetStridedPointer) = pointer(p.ptr)
40-
@inline VectorizationBase.gesp(p::OffsetStridedPointer, i) = similar(p.ptr, gep(p, staticm1(i)))
40+
@inline VectorizationBase.gesp(p::OffsetStridedPointer, i) = similar(p.ptr, gep(p, i))
4141
# @inline VectorizationBase.gesp(p::OffsetStridedPointer, i) = similar(p, gep(p.ptr, i))
4242
# If an OffsetArray is getting indexed by a (loop-)constant value, then this particular vptr object cannot also be eachindexed, so we can safely return a stridedpointer
4343
@inline function VectorizationBase.subsetview(ptr::OffsetStridedPointer{<:Any,N}, ::Val{I}, i) where {I,N}

0 commit comments

Comments
 (0)