Skip to content

Commit 3b7344c

Browse files
authored
Interpret some constant expressions, fixes #305 (#306)
* Interpret some constant expressions, fixes #305 * Update register pressure determination
1 parent 95ba84d commit 3b7344c

30 files changed

+824
-580
lines changed

.github/workflows/ci-julia-nightly.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ jobs:
3838
- part3
3939
- part4
4040
- part5
41+
- part6
42+
- part7
43+
- part8
4144
steps:
4245
- uses: actions/checkout@v2
4346
- uses: julia-actions/setup-julia@v1

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ jobs:
3939
- part3
4040
- part4
4141
- part5
42+
- part6
43+
- part7
44+
- part8
4245
steps:
4346
- uses: actions/checkout@v2
4447
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.53"
4+
version = "0.12.54"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -30,5 +30,5 @@ Static = "0.2, 0.3"
3030
StrideArraysCore = "0.1.12"
3131
ThreadingUtilities = "0.4.5"
3232
UnPack = "1"
33-
VectorizationBase = "0.20.23"
33+
VectorizationBase = "0.20.25"
3434
julia = "1.5"

src/LoopVectorization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using VectorizationBase: register_size, register_count, cache_linesize, cache_si
66
mask, pick_vector_width, MM, AbstractMask, data, grouped_strided_pointer, AbstractSIMD,
77
vzero, offsetprecalc, lazymul,
88
vadd_nw, vadd_nsw, vadd_nuw, vsub_nw, vsub_nsw, vsub_nuw, vmul_nw, vmul_nsw, vmul_nuw,
9+
vfmaddsub, vfmsubadd, vpermilps177, vmovsldup, vmovshdup,
910
maybestaticfirst, maybestaticlast, gep, gesp, NativeTypes, #llvmptr,
1011
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd_fast, vfmsub_fast, vfnmadd_fast, vfnmsub_fast, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231,
1112
vfma_fast, vmuladd_fast, vdiv_fast, vadd_fast, vsub_fast, vmul_fast,

src/codegen/loopstartstopmanager.jl

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -439,49 +439,61 @@ end
439439
# end
440440
# return nothing
441441
# end
442+
443+
442444
function adjust_offsets!(
443445
ls::LoopSet, i::Int,
444446
array_refs_with_same_name::Vector{Int}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
445447
)
446448
ops = operations(ls)
447-
@assert length(ops) ≤ 256
448-
offsets::Base.RefValue{NTuple{256,Int8}} = Base.RefValue{NTuple{256,Int8}}();
449-
GC.@preserve offsets begin
449+
if length(ops) ≤ 256
450+
offsets = Ref{NTuple{256,Int8}}()
450451
poffsets = Base.unsafe_convert(Ptr{Int8}, offsets)
451-
minoffset = typemax(Int8)
452-
maxoffset = typemin(Int8)
453-
# stridesunequal = false
452+
GC.@preserve offsets adjust_offsets!(ls, i, poffsets, array_refs_with_same_name, arrayref_to_name_op_collection)
453+
else
454+
offsetsv = similar(ops, Int8)
455+
poffsets = pointer(offsetsv)
456+
GC.@preserve offsetsv adjust_offsets!(ls, i, poffsets, array_refs_with_same_name, arrayref_to_name_op_collection)
457+
end
458+
end
459+
function adjust_offsets!(
460+
ls::LoopSet, i::Int, poffsets::Ptr{Int8},
461+
array_refs_with_same_name::Vector{Int}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
462+
)
463+
ops = operations(ls)
464+
minoffset = typemax(Int8)
465+
maxoffset = typemin(Int8)
466+
# stridesunequal = false
467+
for j ∈ array_refs_with_same_name
468+
arrayref_to_name_op = arrayref_to_name_op_collection[j]
469+
for (_,__,opid) ∈ arrayref_to_name_op
470+
opref = ops[opid].ref
471+
off = getoffsets(opref)[i]
472+
minoffset = min(off, minoffset)
473+
maxoffset = max(off, maxoffset)
474+
unsafe_store!(poffsets, off, opid)
475+
# stridesunequal |= (stride ≠ getstrides(opref)[i])
476+
end
477+
end
478+
constoffset = Int(minoffset)
479+
constoffset = Core.ifelse(Int(maxoffset) - constoffset > 127, 0, constoffset)
480+
if constoffset ≠ 0
454481
for j ∈ array_refs_with_same_name
455482
arrayref_to_name_op = arrayref_to_name_op_collection[j]
456483
for (_,__,opid) ∈ arrayref_to_name_op
457484
opref = ops[opid].ref
458-
off = getoffsets(opref)[i]
459-
minoffset = min(off, minoffset)
460-
maxoffset = max(off, maxoffset)
461-
unsafe_store!(poffsets, off, opid)
462-
# stridesunequal |= (stride ≠ getstrides(opref)[i])
463-
end
464-
end
465-
constoffset = Int(minoffset)
466-
constoffset = Core.ifelse(Int(maxoffset) - constoffset > 127, 0, constoffset)
467-
if constoffset ≠ 0
468-
for j ∈ array_refs_with_same_name
469-
arrayref_to_name_op = arrayref_to_name_op_collection[j]
470-
for (_,__,opid) ∈ arrayref_to_name_op
471-
opref = ops[opid].ref
472-
newoffset = unsafe_load(poffsets, opid) - constoffset
473-
# if stridesunequal
474-
# stride = getstrides(opref)[i]
475-
# newoffsetint = Int(newoffset) + (Int(stride) - 1)
476-
# # @assert typemin(Int8) ≤ newoffsetint ≤ typemax(Int8)
477-
# newoffset = Int8(newoffsetint)
478-
# end
479-
getoffsets(ops[opid].ref)[i] = newoffset
480-
end
485+
newoffset = unsafe_load(poffsets, opid) - constoffset
486+
# if stridesunequal
487+
# stride = getstrides(opref)[i]
488+
# newoffsetint = Int(newoffset) + (Int(stride) - 1)
489+
# # @assert typemin(Int8) ≤ newoffsetint ≤ typemax(Int8)
490+
# newoffset = Int8(newoffsetint)
491+
# end
492+
getoffsets(ops[opid].ref)[i] = newoffset
481493
end
482494
end
483495
end
484-
constoffset#, Core.ifelse(stridesunequal, 1, Int(stride))
496+
return constoffset#, Core.ifelse(stridesunequal, 1, Int(stride))
485497
end
486498

487499
function calcgespinds(

src/codegen/lower_constant.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,60 @@ function lower_licm_constants!(ls::LoopSet)
235235
setop!(ls, ops[id], Expr(:call, reduction_zero(f), ELTYPESYMBOL))
236236
end
237237
end
238+
239+
"""
    pushconstvalue!(v::Vector{Any}, ls::LoopSet, op::Operation)::Bool

Look up the compile-time value recorded for the constant operation `op` in the
preamble tables of `ls` and, if one is found, append it to `v`.

The tables are searched in this order: `preamble_symint`, `preamble_symfloat`,
`preamble_zeros`, `preamble_funcofeltypes`. The first entry whose id equals
`identifier(op)` decides the outcome.

Returns `false` when a value was pushed onto `v`. Returns `true` when nothing
was pushed: `op` is not constant, no table holds an entry for it, or its
`reduction_zero` is not one of the kinds handled below.
"""
function pushconstvalue!(v::Vector{Any}, ls::LoopSet, op::Operation)::Bool
    isconstant(op) || return true
    opid = identifier(op)

    # Integer literals: (value, size tag, signedness).
    for (id, (intval, intsz, signed)) ∈ ls.preamble_symint
        id == opid || continue
        # NOTE(review): this chain used to contain a second `intsz == 1` branch
        # (pushing `Int8`/`UInt8`) directly after the `Bool` branch. It could
        # never run because the `Bool` branch matches the same value first, so
        # it is dropped here with no change in behaviour. If 8-bit integers are
        # meant to be told apart from `Bool`, confirm how `intsz` is encoded
        # where `preamble_symint` is filled and add the proper condition.
        if intsz == 1
            push!(v, intval % Bool)
        elseif intsz == 2
            signed ? push!(v, intval % Int16) : push!(v, intval % UInt16)
        elseif intsz == 4
            signed ? push!(v, intval % Int32) : push!(v, intval % UInt32)
        else
            signed ? push!(v, intval) : push!(v, unsigned(intval))
        end
        return false
    end

    # Floating-point literals are stored as-is.
    for (id, floatval) ∈ ls.preamble_symfloat
        if id == opid
            push!(v, floatval)
            return false
        end
    end

    # Zeros: a float zero for `HardFloat`, an integer zero otherwise.
    for (id, typ) ∈ ls.preamble_zeros
        id == opid || continue
        if typ == HardFloat
            push!(v, 0.0)
        else
            push!(v, 0)
        end
        return false
    end

    # Reduction initial values, keyed by the kind `reduction_zero(f)` reports.
    for (id, f) ∈ ls.preamble_funcofeltypes
        id == opid || continue
        x = reduction_zero(f)
        if x == ADDITIVE_IN_REDUCTIONS
            push!(v, 0)
        elseif x == MULTIPLICATIVE_IN_REDUCTIONS
            push!(v, 1)
        elseif x == MAX
            push!(v, -Inf)
        elseif x == MIN
            push!(v, Inf)
        elseif x == ALL
            push!(v, true)
        elseif x == ANY
            push!(v, false)
        else
            # Unrecognised reduction kind: report "no value available".
            return true
        end
        return false
    end

    return true
end

src/codegen/lower_store.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function lower_store!(
211211
if reductfunc === Symbol("")
212212
Expr(:call, lv(:_vstore!), sptrsym, gf(mvard,u), inds)
213213
else
214-
Expr(:call, lv(:_vstore!), lv(reductfunc), sptrsym, mvaru, inds)
214+
Expr(:call, lv(:_vstore!), lv(reductfunc), sptrsym, gf(mvard,u), inds)
215215
end
216216
elseif reductfunc === Symbol("")
217217
Expr(:call, lv(:_vstore!), sptrsym, mvar, inds)
@@ -282,13 +282,11 @@ function lower_tiled_store!(blockq::Expr, op::Operation, ls::LoopSet, ua::Unroll
282282
# opp = only(parents(opp))
283283
end
284284
isu₁, isu₂ = isunrolled_sym(opp, u₁loopsym, u₂loopsym, vloopsym, ls)#, u₂)
285-
@assert isu₂
286-
# It's reasonable forthis to be `!isu₁`
285+
# It's reasonable for this to be `!isu₁`
287286
u = Core.ifelse(isu₁, u₁, 1)
288287
tup = Expr(:tuple)
289288
for t ∈ 0:u₂-1
290-
mvar = Symbol(variable_name(opp, t), '_', u)
291-
push!(tup.args, mvar)
289+
push!(tup.args, Symbol(variable_name(opp, ifelse(isu₂, t, -1)), '_', u))
292290
end
293291
vut = Expr(:call, lv(:VecUnroll), tup) # `VecUnroll` of `VecUnroll`s
294292
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, 0, ls)

src/codegen/lowering.jl

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -873,33 +873,34 @@ function lower_unrollspec(ls::LoopSet)
873873
end
874874

875875
function lower(ls::LoopSet, order, u₁loop, u₂loop, vectorized, u₁, u₂, inline::Bool)
876-
cacheunrolled!(ls, u₁loop, u₂loop, vectorized)
877-
fillorder!(ls, order, u₁loop, u₂loop, u₂, vectorized)
878-
ls.unrollspecification = UnrollSpecification(ls, u₁loop, u₂loop, vectorized, u₁, u₂)
879-
q = lower_unrollspec(ls)
880-
inline && pushfirst!(q.args, Expr(:meta, :inline))
881-
q
876+
cacheunrolled!(ls, u₁loop, u₂loop, vectorized)
877+
fillorder!(ls, order, u₁loop, u₂loop, u₂, vectorized)
878+
ls.unrollspecification = UnrollSpecification(ls, u₁loop, u₂loop, vectorized, u₁, u₂)
879+
q = lower_unrollspec(ls)
880+
inline && pushfirst!(q.args, Expr(:meta, :inline))
881+
q
882882
end
883883

884884
function lower(ls::LoopSet, inline::Int = -1)
885-
fill_offset_memop_collection!(ls)
886-
order, u₁loop, u₂loop, vectorized, u₁, u₂, c, shouldinline = choose_order_cost(ls)
887-
lower(ls, order, u₁loop, u₂loop, vectorized, u₁, u₂, inlinedecision(inline, shouldinline))
885+
fill_offset_memop_collection!(ls)
886+
order, u₁loop, u₂loop, vectorized, u₁, u₂, c, shouldinline = choose_order_cost(ls)
887+
lower(ls, order, u₁loop, u₂loop, vectorized, u₁, u₂, inlinedecision(inline, shouldinline))
888888
end
889889
function lower(ls::LoopSet, u₁::Int, u₂::Int, inline::Int)
890-
fill_offset_memop_collection!(ls)
891-
if u₂ > 1
892-
@assert num_loops(ls) > 1 "There is only $(num_loops(ls)) loop, but specified blocking parameter u₂ is $u₂."
893-
order, u₁loop, u₂loop, vectorized, _u₁, _u₂, c, shouldinline = choose_tile(ls)
894-
copyto!(ls.loop_order.bestorder, order)
895-
else
896-
u₂ = -1
897-
order, vectorized, c = choose_unroll_order(ls, Inf)
898-
u₁loop = first(order); u₂loop = Symbol("##undefined##"); shouldinline = true
899-
copyto!(ls.loop_order.bestorder, order)
900-
end
901-
doinline = inlinedecision(inline, shouldinline)
902-
lower(ls, order, u₁loop, u₂loop, vectorized, u₁, u₂, doinline)
890+
fill_offset_memop_collection!(ls)
891+
fill_children!(ls)
892+
if u₂ > 1
893+
@assert num_loops(ls) > 1 "There is only $(num_loops(ls)) loop, but specified blocking parameter u₂ is $u₂."
894+
order, u₁loop, u₂loop, vectorized, _u₁, _u₂, c, shouldinline = choose_tile(ls)
895+
copyto!(ls.loop_order.bestorder, order)
896+
else
897+
u₂ = -1
898+
order, vectorized, c = choose_unroll_order(ls, Inf)
899+
u₁loop = first(order); u₂loop = Symbol("##undefined##"); shouldinline = true
900+
copyto!(ls.loop_order.bestorder, order)
901+
end
902+
doinline = inlinedecision(inline, shouldinline)
903+
lower(ls, order, u₁loop, u₂loop, vectorized, u₁, u₂, doinline)
903904
end
904905

905906
# Base.convert(::Type{Expr}, ls::LoopSet) = lower(ls)

src/condense_loopset.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ function generate_call_types(
645645
for op ∈ ops
646646
instr::Instruction = instruction(op)
647647
if (isconstant(op) && (instr == LOOPCONSTANT)) && (!roots[identifier(op)])
648-
instr = op.instruction = DROPPEDCONSTANT
648+
instr = op.instruction = DROPPEDCONSTANT
649649
end
650650
push!(operation_descriptions.args, QuoteNode(instr.mod))
651651
push!(operation_descriptions.args, QuoteNode(instr.instr))
@@ -790,7 +790,7 @@ function setup_call_debug(ls::LoopSet)
790790
generate_call(ls, (false,zero(Int8),zero(Int8)), zero(UInt), true)
791791
end
792792
function setup_call(
793-
ls::LoopSet, q::Expr, source::LineNumberNode, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, thread::Int, warncheckarg::Bool
793+
ls::LoopSet, q::Expr, source::LineNumberNode, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, thread::Int, warncheckarg::Int
794794
)
795795
# We outline/inline at the macro level by creating/not creating an anonymous function.
796796
# The old API instead was based on inlining or not inline the generated function, but
@@ -802,7 +802,11 @@ function setup_call(
802802
call = generate_call(ls, (inline, u₁, u₂), thread%UInt, false)
803803
call = check_empty ? check_if_empty(ls, call) : call
804804
argfailure = make_crashy(make_fast(q))
805-
warncheckarg && (argfailure = Expr(:block, :(@warn "`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead." maxlog=1), argfailure))
805+
if warncheckarg ≠ 0
806+
warning = :(@warn "`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.")
807+
warncheckarg > 0 && push!(warning.args, :(maxlog=$warncheckarg))
808+
argfailure = Expr(:block, warning, argfailure)
809+
end
806810
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure))
807811
prepend_lnns!(ls.prepreamble, lnns)
808812
return ls.prepreamble

src/constructors.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function add_ci_call!(q::Expr, @nospecialize(f), args, syms, i, valarg = nothing
3535
push!(q.args, Expr(:(=), syms[i], call))
3636
end
3737

38-
function substitute_broadcast(q::Expr, mod::Symbol, inline, u₁, u₂, threads, warncheckarg)
38+
function substitute_broadcast(q::Expr, mod::Symbol, inline::Bool, u₁::Int8, u₂::Int8, threads::Int, warncheckarg::Int)
3939
ci = first(Meta.lower(LoopVectorization, q).args).code
4040
nargs = length(ci)-1
4141
ex = Expr(:block)
@@ -75,7 +75,7 @@ function loopset(q::Expr) # for interactive use only
7575
ls
7676
end
7777

78-
function check_macro_kwarg(arg, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, threads::Int, warncheckarg::Bool)
78+
function check_macro_kwarg(arg, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, threads::Int, warncheckarg::Int)
7979
((arg.head === :(=)) && (length(arg.args) == 2)) || throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
8080
kw = (arg.args[1])::Symbol
8181
value = (arg.args[2])
@@ -101,29 +101,28 @@ function check_macro_kwarg(arg, inline::Bool, check_empty::Bool, u₁::Int8, u
101101
throw(ArgumentError("Don't know how to process argument in `thread=$value`."))
102102
end
103103
elseif kw === :warn_check_args
104-
warncheckarg = value::Bool
104+
warncheckarg = convert(Int, value)::Int
105105
else
106106
throw(ArgumentError("Received unrecognized keyword argument $kw. Recognized arguments include:\n`inline`, `unroll`, `check_empty`, and `thread`."))
107107
end
108108
inline, check_empty, u₁, u₂, threads, warncheckarg
109109
end
110-
function process_args(args; inline = false, check_empty = false, u₁ = zero(Int8), u₂ = zero(Int8), threads = 1, warncheckarg = false)
111-
for arg ∈ args
112-
inline, check_empty, u₁, u₂, threads, warncheckarg = check_macro_kwarg(arg, inline, check_empty, u₁, u₂, threads, warncheckarg)
113-
end
114-
inline, check_empty, u₁, u₂, threads, warncheckarg
110+
function process_args(args; inline::Bool = false, check_empty::Bool = false, u₁::Int8 = zero(Int8), u₂::Int8 = zero(Int8), threads::Int = 1, warncheckarg::Int = 0)
111+
for arg ∈ args
112+
inline, check_empty, u₁, u₂, threads, warncheckarg = check_macro_kwarg(arg, inline, check_empty, u₁, u₂, threads, warncheckarg)
113+
end
114+
inline, check_empty, u₁, u₂, threads, warncheckarg
115115
end
116116
function turbo_macro(mod, src, q, args...)
117-
q = macroexpand(mod, q)
118-
119-
if q.head === :for
120-
ls = LoopSet(q, mod)
121-
inline, check_empty, u₁, u₂, threads, warncheckarg = process_args(args)
122-
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, threads, warncheckarg))
123-
else
124-
inline, check_empty, u₁, u₂, threads, warncheckarg = process_args(args, inline=true)
125-
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, threads, warncheckarg)
126-
end
117+
q = macroexpand(mod, q)
118+
if q.head === :for
119+
ls = LoopSet(q, mod)
120+
inline, check_empty, u₁, u₂, threads, warncheckarg = process_args(args)
121+
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, threads, warncheckarg))
122+
else
123+
inline, check_empty, u₁, u₂, threads, warncheckarg = process_args(args, inline=true)
124+
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, threads, warncheckarg)
125+
end
127126
end
128127
"""
129128
@turbo
@@ -215,6 +214,8 @@ use their `parent`. Triangular loops aren't yet supported.
215214
Setting the keyword argument `warn_check_args=true` (e.g. `@turbo warn_check_args=true for ...`) in a loop or
216215
broadcast statement will cause it to warn once if `LoopVectorization.check_args` fails and the fallback
217216
loop is executed instead of the LoopVectorization-optimized loop.
217+
Setting it to an integer > 0 will warn that many times, while setting it to a negative integer will warn
218+
an unlimited number of times. The default is `warn_check_args = 0`, which disables the warning.
218219
"""
219220
macro turbo(args...)
220221
turbo_macro(__module__, __source__, last(args), Base.front(args)...)
@@ -256,7 +257,7 @@ end
256257
macro _turbo(arg, q)
257258
@assert q.head === :for
258259
q = macroexpand(__module__, q)
259-
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), 1, false)
260+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), 1, 0)
260261
ls = LoopSet(q, __module__)
261262
set_hw!(ls)
262263
def_outer_reduct_types!(ls)

0 commit comments

Comments
 (0)