Skip to content

Commit 718a6e7

Browse files
committed
When unrolling, calculate initial offset pointer, and then calculate remaining loads with respect to that.
1 parent 3e37f91 commit 718a6e7

19 files changed

+334
-218
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ DocStringExtensions = "0.8"
1717
OffsetArrays = "1"
1818
SIMDPirates = "0.7.24"
1919
SLEEFPirates = "0.4.8"
20-
UnPack = "0"
20+
UnPack = "0,1"
2121
VectorizationBase = "0.11.3"
2222
julia = "1.1"
2323

src/LoopVectorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ module LoopVectorization
33
using VectorizationBase, SIMDPirates, SLEEFPirates, UnPack, OffsetArrays
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM,
6-
maybestaticlength, maybestaticsize, staticm1, subsetview, vzero, stridedpointer_for_broadcast,
6+
maybestaticlength, maybestaticsize, staticm1, staticp1, subsetview, vzero, stridedpointer_for_broadcast,
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
88
AbstractColumnMajorStridedPointer, AbstractRowMajorStridedPointer, AbstractSparseStridedPointer, AbstractStaticStridedPointer,
99
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct,
10-
maybestaticfirst, maybestaticlast, scalar_less, scalar_greater, noalias!
10+
maybestaticfirst, maybestaticlast, scalar_less, scalar_greater, noalias!, gesp
1111
using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange,
1212
reduced_add, reduced_prod, reduce_to_add, reduced_max, reduced_min, vsum, vprod, vmaximum, vminimum,
1313
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,

src/add_constants.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementby
3636
temp = gensym(:intermediateconstref)
3737
vloadcall = Expr(:call, lv(:vload), mpref.mref.ptr)
3838
if length(getindices(op)) > 0
39-
push!(vloadcall.args, mem_offset(op, UnrollArgs(0, Symbol(""), Symbol(""), Symbol(""), nothing)))
39+
push!(vloadcall.args, mem_offset(op, UnrollArgs(0, Symbol(""), Symbol(""), Symbol(""), nothing), false, false))
4040
end
4141
pushpreamble!(ls, Expr(:(=), temp, vloadcall))
4242
pushpreamble!(ls, op, temp)

src/condense_loopset.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ function add_external_functions!(q::Expr, ls::LoopSet)
195195
end
196196
end
197197

198+
function check_if_empty(ls::LoopSet, q::Expr)
199+
lb = loop_boundaries(ls)
200+
Expr(:if, Expr(:call, :!, Expr(:call, :any, :isempty, lb)), q)
201+
end
202+
198203
# Try to condense in type stable manner
199204
function generate_call(ls::LoopSet, inline_unroll::NTuple{3,Int8}, debug::Bool = false)
200205
operation_descriptions = Expr(:curly, :Tuple)
@@ -247,6 +252,8 @@ Additionally, define `pointer` and `stride` methods.
247252
@inline check_args(A::PermutedDimsArray) = check_args(parent(A))
248253
@inline check_args(A::StridedArray) = check_type(eltype(A))
249254
@inline check_args(A::AbstractRange) = check_type(eltype(A))
255+
@inline check_args(A::BitVector) = true
256+
@inline check_args(A::BitMatrix) = true
250257
@inline function check_args(A::AbstractArray)
251258
M = parentmodule(typeof(A))
252259
if parent(A) === A # SparseMatrix, StaticArray, etc
@@ -305,15 +312,16 @@ end
305312
function setup_call_debug(ls::LoopSet)
306313
# avx_loopset(instr, ops, arf, AM, LB, vargs)
307314
pushpreamble!(ls, generate_call(ls, (zero(Int8),zero(Int8),zero(Int8)), true))
308-
ls.preamble
315+
Expr(:block, ls.prepreamble, ls.preamble)
309316
end
310-
function setup_call(ls::LoopSet, q = nothing, inline::Int8 = zero(Int8), u₁::Int8 = zero(Int8), u₂::Int8 = zero(Int8))
317+
function setup_call(ls::LoopSet, q = nothing, inline::Int8 = zero(Int8), check_empty::Bool = false, u₁::Int8 = zero(Int8), u₂::Int8 = zero(Int8))
311318
# We outline/inline at the macro level by creating/not creating an anonymous function.
312319
# The old API instead was based on inlining or not inlining the generated function, but
313320
# the generated function must be inlined into the initial loop preamble for performance reasons.
314321
# Creating an anonymous function and calling it also achieves the outlining, while still
315322
# inlining the generated function into the loop preamble.
316323
call = setup_call_inline(ls, inline, u₁, u₂)
324+
call = check_empty ? check_if_empty(ls, call) : call
317325
isnothing(q) && return Expr(:block, ls.prepreamble, call)
318326
Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), call, make_fast_and_crashy(q)))
319327
end

src/constructors.jl

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ where `inline=false` is faster, so the best setting may require experimentation.
112112
tries to guess. Currently the algorithm is simple: roughly, if there are more than two dynamically sized loops
113113
or and no convolutions, it will probably not force inlining. Otherwise, it probably will.
114114
115+
`check_empty` (default is `false`) determines whether or not it will check if any of the iterators are empty.
116+
If `false`, you must ensure yourself that they are not empty; otherwise the behavior of the loop is undefined and
117+
(like with `@inbounds`) segmentation faults are likely.
118+
115119
`unroll` is an integer that specifies the loop unrolling factor, or a
116120
tuple `(u₁, u₂) = (4, 2)` signaling that the generated code should unroll more than
117121
one loop. `u₁` is the unrolling factor for the first unrolled loop and `u₂` for the next (if present),
@@ -165,30 +169,46 @@ function check_unroll(arg)
165169
end
166170
u₁, u₂
167171
end
168-
function check_macro_kwarg(arg, inline::Int8 = zero(Int8), u₁::Int8 = zero(Int8), u₂::Int8 = zero(Int8))
172+
function check_checkempty(arg)
173+
arg.args[1] === :check_empty ? (arg.args[2])::Bool : nothing
174+
end
175+
function check_macro_kwarg(arg, inline::Int8 = zero(Int8), check_empty::Bool = false, u₁::Int8 = zero(Int8), u₂::Int8 = zero(Int8))
169176
@assert arg.head === :(=)
170177
i = check_inline(arg)
171178
if iszero(i)
172-
u₁, u₂ = check_unroll(arg)
179+
ce = check_checkempty(arg)
180+
if isnothing(ce)
181+
u₁, u₂ = check_unroll(arg)
182+
else
183+
check_empty = ce
184+
end
173185
else
174186
inline = i
175187
end
176-
inline, u₁, u₂
188+
inline, check_empty, u₁, u₂
177189
end
178190
macro avx(arg, q)
179191
@assert q.head === :for
180192
@assert arg.head === :(=)
181193
q = macroexpand(__module__, q)
182-
inline, u₁, u₂ = check_macro_kwarg(arg)
194+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg)
183195
ls = LoopSet(q, __module__)
184-
esc(setup_call(ls, q, inline, u₁, u₂))
196+
esc(setup_call(ls, q, inline, check_empty, u₁, u₂))
185197
end
186198
macro avx(arg1, arg2, q)
187199
@assert q.head === :for
188200
q = macroexpand(__module__, q)
189-
inline, u₁, u₂ = check_macro_kwarg(arg1)
190-
inline, u₁, u₂ = check_macro_kwarg(arg2, inline, u₁, u₂)
191-
esc(setup_call(LoopSet(q, __module__), q, inline, u₁, u₂))
201+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg1)
202+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg2, inline, check_empty, u₁, u₂)
203+
esc(setup_call(LoopSet(q, __module__), q, inline, check_empty, u₁, u₂))
204+
end
205+
macro avx(arg1, arg2, arg3, q)
206+
@assert q.head === :for
207+
q = macroexpand(__module__, q)
208+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg1)
209+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg2, inline, check_empty, u₁, u₂)
210+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg3, inline, check_empty, u₁, u₂)
211+
esc(setup_call(LoopSet(q, __module__), q, inline, check_empty, u₁, u₂))
192212
end
193213

194214

@@ -197,20 +217,25 @@ end
197217
198218
This macro transforms loops similarly to [`@avx`](@ref).
199219
While `@avx` punts to a generated function to enable type-based analysis, `@_avx`
200-
works on just the expressions. This requires that it makes a number of default assumptions.
220+
works on just the expressions. This requires that it makes a number of default assumptions. Use of `@avx` is preferred.
221+
222+
This macro accepts the `inline` and `unroll` keyword arguments like `@avx`, but ignores the `check_empty` argument.
201223
"""
202224
macro _avx(q)
203225
q = macroexpand(__module__, q)
204-
esc(lower_and_split_loops(LoopSet(q, __module__), -1))
226+
ls = LoopSet(q, __module__)
227+
esc(Expr(:block, ls.prepreamble, lower_and_split_loops(ls, -1)))
205228
end
206229
macro _avx(arg, q)
207230
@assert q.head === :for
208231
q = macroexpand(__module__, q)
209-
inline, u₁, u₂ = check_macro_kwarg(arg)
210-
esc(lower(LoopSet(q, __module__), u₁ % Int, u₂ % Int, -1))
232+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg)
233+
ls = LoopSet(q, __module__)
234+
esc(Expr(:block, ls.prepreamble, lower(ls, u₁ % Int, u₂ % Int, -1)))
211235
end
212236

213237
macro avx_debug(q)
214238
q = macroexpand(__module__, q)
215-
esc(LoopVectorization.setup_call_debug(LoopSet(q, __module__)))
239+
ls = LoopSet(q, __module__)
240+
esc(LoopVectorization.setup_call_debug(ls))
216241
end

src/determinestrategy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
767767
cost_temp = evaluate_cost_unroll(ls, new_order, new_vec, lowest_cost)
768768
if cost_temp < lowest_cost
769769
lowest_cost = cost_temp
770-
best_order = new_order
770+
copyto!(best_order, new_order)
771771
best_vec = new_vec
772772
end
773773
end

src/graphs.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ Base.length(loop::Loop) = 1 + loop.stophint - loop.starthint
7171
isstaticloop(loop::Loop) = loop.startexact & loop.stopexact
7272
function startloop(loop::Loop, isvectorized, itersymbol)
7373
startexact = loop.startexact
74-
if isvectorized
75-
if startexact
76-
Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.starthint))
77-
else
78-
Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.startsym))
79-
end
80-
elseif startexact
74+
# if isvectorized
75+
# if startexact
76+
# Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.starthint))
77+
# else
78+
# Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.startsym))
79+
# end
80+
if startexact
8181
Expr(:(=), itersymbol, loop.starthint)
8282
else
8383
Expr(:(=), itersymbol, Expr(:call, lv(:unwrap), loop.startsym))
@@ -91,17 +91,17 @@ function vec_looprange(loop::Loop, UF::Int, mangledname::Symbol)
9191
Expr(:call, lv(:valsub), VECTORWIDTHSYMBOL, 2)
9292
end
9393
if loop.stopexact # split for type stability
94-
Expr(:call, lv(:scalar_less), mangledname, Expr(:call, :-, loop.stophint, incr))
94+
Expr(:call, :<, mangledname, Expr(:call, lv(:vsub), loop.stophint, incr))
9595
else
96-
Expr(:call, lv(:scalar_less), mangledname, Expr(:call, :-, loop.stopsym, incr))
96+
Expr(:call, :<, mangledname, Expr(:call, lv(:vsub), loop.stopsym, incr))
9797
end
9898
end
9999
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
100100
incr = 2 - incr
101101
if iszero(incr)
102-
Expr(:call, lv(:scalar_less), mangledname, loop.stopexact ? loop.stophint : loop.stopsym)
102+
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint : loop.stopsym)
103103
else
104-
Expr(:call, lv(:scalar_less), mangledname, loop.stopexact ? loop.stophint + incr : Expr(:call, :+, loop.stopsym, incr))
104+
Expr(:call, :<, mangledname, loop.stopexact ? loop.stophint + incr : Expr(:call, lv(:vadd), loop.stopsym, incr))
105105
end
106106
end
107107
function terminatecondition(
@@ -119,11 +119,12 @@ function incrementloopcounter(us::UnrollSpecification, n::Int, mangledname::Symb
119119
if isvectorized(us, n)
120120
if UF == 1
121121
Expr(:(=), mangledname, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, mangledname))
122+
# Expr(:(=), mangledname, Expr(:macrocall, Symbol("@show"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, mangledname)))
122123
else
123-
Expr(:+=, mangledname, Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UF))
124+
Expr(:(=), mangledname, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, UF, mangledname))
124125
end
125126
else
126-
Expr(:+=, mangledname, UF)
127+
Expr(:(=), mangledname, Expr(:call, lv(:vadd), mangledname, UF))
127128
end
128129
end
129130

@@ -367,7 +368,7 @@ function add_loop_bound!(ls::LoopSet, itersym::Symbol, bound, upper::Bool = true
367368
(bound isa Symbol && upper) && return bound
368369
bound isa Expr && maybestatic!(bound)
369370
N = gensym(string(itersym) * (upper ? "_loop_upper_bound" : "_loop_lower_bound"))
370-
pushpreamble!(ls, Expr(:(=), N, bound))
371+
pushprepreamble!(ls, Expr(:(=), N, bound))
371372
N
372373
end
373374

@@ -410,20 +411,20 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
410411
else
411412
otN isa Expr && maybestatic!(otN)
412413
N = gensym("loop" * string(itersym))
413-
pushpreamble!(ls, Expr(:(=), N, otN))
414+
pushprepreamble!(ls, Expr(:(=), N, otN))
414415
Loop(itersym, 1, N)
415416
end
416417
else
417418
N = gensym("loop" * string(itersym))
418-
pushpreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticrange), r)))
419+
pushprepreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticrange), r)))
419420
L = add_loop_bound!(ls, itersym, Expr(:call, lv(:maybestaticfirst), N), false)
420421
U = add_loop_bound!(ls, itersym, Expr(:call, lv(:maybestaticlast), N), true)
421422
Loop(itersym, L, U)
422423
end
423424
elseif isa(r, Symbol)
424425
# Treat similar to `eachindex`
425426
N = gensym("loop" * string(itersym))
426-
pushpreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticrange), r)))
427+
pushprepreamble!(ls, Expr(:(=), N, Expr(:call, lv(:maybestaticrange), r)))
427428
L = add_loop_bound!(ls, itersym, Expr(:call, lv(:maybestaticfirst), N), false)
428429
U = add_loop_bound!(ls, itersym, Expr(:call, lv(:maybestaticlast), N), true)
429430
loop = Loop(itersym, L, U)

src/lower_compute.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ end
7171
function add_loopvalue!(instrcall::Expr, loopval::Symbol, vectorized::Symbol, u::Int)
7272
if loopval === vectorized
7373
if isone(u)
74-
push!(instrcall.args, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, loopval))
74+
push!(instrcall.args, Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, _MMind(loopval)))
7575
else
76-
push!(instrcall.args, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, u, loopval))
76+
push!(instrcall.args, Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, u, _MMind(loopval)))
7777
end
7878
else
7979
push!(instrcall.args, Expr(:call, :+, loopval, u))
@@ -85,6 +85,8 @@ function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u::Int)
8585
add_loopvalue!(instrcall, loopval, vectorized, u)
8686
elseif !isnothing(suffix) && suffix > 0 && loopval === u₂loopsym
8787
add_loopvalue!(instrcall, loopval, vectorized, suffix)
88+
elseif loopval === vectorized
89+
push!(instrcall.args, _MMind(loopval))
8890
else
8991
push!(instrcall.args, loopval)
9092
end

src/lower_load.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ function lower_load_scalar!(
66
@assert vectorized loopdeps
77
# mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loop, u₂loop, suffix)
88
mvar = variable_name(op, suffix)
9-
ptr = refname(op)
109
opu₁ = u₁loopsym loopdeps
10+
unrolled = (opu₁ || u₂loopsym loopdeps)
11+
ptr = unrolled ? offset_refname(op) : refname(op)
1112
U = opu₁ ? u₁ : 1
1213
if instruction(op).instr !== :conditionalload
1314
for u umin:U-1
1415
varname = varassignname(mvar, u, opu₁)
1516
td = UnrollArgs(ua, u)
16-
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))))
17+
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td, unrolled))))
1718
end
1819
else
1920
opu₂ = !isnothing(suffix) && u₂loopsym loopdeps
@@ -23,20 +24,20 @@ function lower_load_scalar!(
2324
condsym = varassignname(condvar, u, condu₁)
2425
varname = varassignname(mvar, u, u₁loopsym loopdependencies(op))
2526
td = UnrollArgs(ua, u)
26-
load = Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))
27+
load = Expr(:call, lv(:vload), ptr, mem_offset_u(op, td, unrolled))
2728
cload = Expr(:if, condsym, load, Expr(:call, :zero, Expr(:call, :eltype, ptr)))
2829
push!(q.args, Expr(:(=), varname, cload))
2930
end
3031
end
3132
nothing
3233
end
3334
function pushvectorload!(
34-
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, vectorized::Symbol, mask
35+
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, vectorized::Symbol, mask, u₁unrolled::Bool, unrolled::Bool
3536
)
3637
@unpack u₁, u₁loopsym, u₂loopsym, suffix = td
37-
ptr = refname(op)
38+
ptr = unrolled ? offset_refname(op) : refname(op)
3839
vecnotunrolled = vectorized !== u₁loopsym
39-
name, mo = name_memoffset(var, op, td)
40+
name, mo = name_memoffset(var, op, td, u₁unrolled, unrolled)
4041
instrcall = Expr(:call, lv(:vload), ptr, mo)
4142

4243
iscondstore = instruction(op).instr === :conditionalload
@@ -96,7 +97,9 @@ function lower_load_vectorized!(
9697
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = td
9798
loopdeps = loopdependencies(op)
9899
@assert vectorized loopdeps
99-
if u₁loopsym loopdeps
100+
opu₁ = u₁loopsym loopdeps
101+
unrolled = opu₁ || u₂loopsym loopdeps
102+
if opu₁
100103
umin = umin
101104
U = u₁
102105
else
@@ -107,25 +110,24 @@ function lower_load_vectorized!(
107110
var = variable_name(op, suffix)
108111
for u umin:U-1
109112
td = UnrollArgs(td, u)
110-
pushvectorload!(q, op, var, td, U, vectorized, mask)
113+
pushvectorload!(q, op, var, td, U, vectorized, mask, opu₁, unrolled)
111114
end
112115
prefetchind = prefetchisagoodidea(ls, op, td)
113116
if !iszero(prefetchind)
114117
dontskip = (64 ÷ VectorizationBase.REGISTER_SIZE) - 1
115-
ptr = refname(op)
118+
ptr = offset_refname(op)
116119
innermostloopsym = last(ls.loop_order.bestorder)
117120
us = ls.unrollspecification[]
118121
prefetch_multiplier = 4
119122
prefetch_distance = u₁loopsym === innermostloopsym ? us.u₁ : ( u₂loopsym === innermostloopsym ? us.u₂ : 1 )
120123
prefetch_distance *= prefetch_multiplier
121124
offsets = op.ref.ref.offsets
122125
inner_offset = offsets[prefetchind]
123-
124126
for u umin:U-1
125127
# for u ∈ umin:min(umin,U-1)
126128
(u₁loopsym === vectorized && !iszero(u & dontskip)) && continue
127129
offsets[prefetchind] = inner_offset + prefetch_distance
128-
mo = last(name_memoffset(var, op, UnrollArgs(td, u)))
130+
mo = mem_offset_u(op, UnrollArgs(td, u), true)
129131
instrcall = Expr(:call, lv(:prefetch0), ptr, mo)
130132
push!(q.args, instrcall)
131133
end
@@ -170,6 +172,7 @@ function lower_load!(
170172
umin = 0
171173
end
172174
else
175+
maybegesp_call!(q, op, td)
173176
umin = 0
174177
end
175178
if vectorized loopdependencies(op)

0 commit comments

Comments
 (0)