Skip to content

Commit 9aa0211

Browse files
committed
Add warn_check_args=true option to @turbo so that it can emit a warning when LoopVectorization.check_args fails. Resolves #296
1 parent 649817b commit 9aa0211

File tree

8 files changed

+96
-70
lines changed

8 files changed

+96
-70
lines changed

src/broadcast.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ end
398398
# need to construct the LoopSet
399399
# @show typeof(dest)
400400
ls = LoopSet(Mod)
401-
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, threads = UNROLL
401+
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, threads, warncheckarg = UNROLL
402402
set_hw!(ls, rs, rc, cls, l1, l2, l3)
403403
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
404404
loopsyms = [gensym!(ls, "n") for n ∈ 1:N]
@@ -409,7 +409,7 @@ end
409409
doaddref!(ls, storeop)
410410
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
411411
# return ls
412-
sc = setup_call(ls, :(Base.Broadcast.materialize!(dest, bc)), LineNumberNode(0), inline, false, u₁, u₂, threads%Int)
412+
sc = setup_call(ls, :(Base.Broadcast.materialize!(dest, bc)), LineNumberNode(0), inline, false, u₁, u₂, threads%Int, warncheckarg)
413413
# return sc
414414
Expr(:block, Expr(:meta,:inline), sc, :dest)
415415
end
@@ -419,7 +419,7 @@ end
419419
# we have an N dimensional loop.
420420
# need to construct the LoopSet
421421
ls = LoopSet(Mod)
422-
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, threads = UNROLL
422+
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, threads, warncheckarg = UNROLL
423423
set_hw!(ls, rs, rc, cls, l1, l2, l3)
424424
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
425425
loopsyms = [gensym!(ls, "n") for n ∈ 1:N]
@@ -430,7 +430,7 @@ end
430430
storeop = add_simple_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms)), elementbytes)
431431
doaddref!(ls, storeop)
432432
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
433-
Expr(:block, Expr(:meta,:inline), setup_call(ls, :(Base.Broadcast.materialize!(dest′, bc)), LineNumberNode(0), inline, false, u₁, u₂, threads%Int), :dest′)
433+
Expr(:block, Expr(:meta,:inline), setup_call(ls, :(Base.Broadcast.materialize!(dest′, bc)), LineNumberNode(0), inline, false, u₁, u₂, threads%Int, warncheckarg), :dest′)
434434
end
435435
# these are marked `@inline` so the `@turbo` itself can choose whether or not to inline.
436436
@generated function vmaterialize!(

src/condense_loopset.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -525,10 +525,11 @@ end
525525
::Val{CNFARG}, ::StaticInt{W}, ::StaticInt{RS}, ::StaticInt{AR}, ::StaticInt{NT},
526526
::StaticInt{CLS}, ::StaticInt{L1}, ::StaticInt{L2}, ::StaticInt{L3}
527527
) where {CNFARG,W,RS,AR,CLS,L1,L2,L3,NT}
528-
inline,u₁,u₂,BROADCAST,thread = CNFARG
529-
nt = min(thread % UInt, NT % UInt)
530-
t = Expr(:tuple, inline, u₁, u₂, BROADCAST, W, RS, AR, CLS, L1,L2,L3, nt)
531-
Expr(:call, Expr(:curly, :Val, t))
528+
inline,u₁,u₂,BROADCAST,thread = CNFARG
529+
nt = min(thread % UInt, NT % UInt)
530+
t = Expr(:tuple, inline, u₁, u₂, BROADCAST, W, RS, AR, CLS, L1,L2,L3, nt)
531+
length(CNFARG) == 6 && push!(t.args, last(CNFARG))
532+
Expr(:call, Expr(:curly, :Val, t))
532533
end
533534
@inline function avx_config_val(
534535
::Val{CNFARG}, ::StaticInt{W}
@@ -787,18 +788,20 @@ function setup_call_debug(ls::LoopSet)
787788
generate_call(ls, (false,zero(Int8),zero(Int8)), zero(UInt), true)
788789
end
789790
function setup_call(
790-
ls::LoopSet, q::Expr, source::LineNumberNode, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, thread::Int
791+
ls::LoopSet, q::Expr, source::LineNumberNode, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, thread::Int, warncheckarg::Bool
791792
)
792-
# We outline/inline at the macro level by creating/not creating an anonymous function.
793-
# The old API instead was based on inlining or not inline the generated function, but
794-
# the generated function must be inlined into the initial loop preamble for performance reasons.
795-
# Creating an anonymous function and calling it also achieves the outlining, while still
796-
# inlining the generated function into the loop preamble.
797-
lnns = extract_all_lnns(q)
798-
pushfirst!(lnns, source)
799-
call = generate_call(ls, (inline, u₁, u₂), thread%UInt, false)
800-
call = check_empty ? check_if_empty(ls, call) : call
801-
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, make_crashy(make_fast(q))))
802-
prepend_lnns!(ls.prepreamble, lnns)
803-
return ls.prepreamble
793+
# We outline/inline at the macro level by creating/not creating an anonymous function.
794+
# The old API instead was based on inlining or not inline the generated function, but
795+
# the generated function must be inlined into the initial loop preamble for performance reasons.
796+
# Creating an anonymous function and calling it also achieves the outlining, while still
797+
# inlining the generated function into the loop preamble.
798+
lnns = extract_all_lnns(q)
799+
pushfirst!(lnns, source)
800+
call = generate_call(ls, (inline, u₁, u₂), thread%UInt, false)
801+
call = check_empty ? check_if_empty(ls, call) : call
802+
argfailure = make_crashy(make_fast(q))
803+
warncheckarg && (argfailure = Expr(:block, :(@warn "`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead." maxlog=1), argfailure))
804+
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure))
805+
prepend_lnns!(ls.prepreamble, lnns)
806+
return ls.prepreamble
804807
end

src/constructors.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ function add_ci_call!(q::Expr, @nospecialize(f), args, syms, i, valarg = nothing
3535
push!(q.args, Expr(:(=), syms[i], call))
3636
end
3737

38-
function substitute_broadcast(q::Expr, mod::Symbol, inline, u₁, u₂, threads)
38+
function substitute_broadcast(q::Expr, mod::Symbol, inline, u₁, u₂, threads, warncheckarg)
3939
ci = first(Meta.lower(LoopVectorization, q).args).code
4040
nargs = length(ci)-1
4141
ex = Expr(:block)
4242
syms = [gensym() for _ ∈ 1:nargs]
43-
configarg = (inline,u₁,u₂,true,threads)
43+
configarg = (inline,u₁,u₂,true,threads,warncheckarg)
4444
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0))
4545
for n ∈ 1:nargs
4646
ciₙ = ci[n]
@@ -75,7 +75,7 @@ function loopset(q::Expr) # for interactive use only
7575
ls
7676
end
7777

78-
function check_macro_kwarg(arg, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, threads::Int)
78+
function check_macro_kwarg(arg, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, threads::Int, warncheckarg::Bool)
7979
((arg.head === :(=)) && (length(arg.args) == 2)) || throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
8080
kw = (arg.args[1])::Symbol
8181
value = (arg.args[2])
@@ -100,27 +100,29 @@ function check_macro_kwarg(arg, inline::Bool, check_empty::Bool, u₁::Int8, u
100100
else
101101
throw(ArgumentError("Don't know how to process argument in `thread=$value`."))
102102
end
103+
elseif kw === :warn_check_args
104+
warncheckarg = value::Bool
103105
else
104106
throw(ArgumentError("Received unrecognized keyword argument $kw. Recognized arguments include:\n`inline`, `unroll`, `check_empty`, and `thread`."))
105107
end
106-
inline, check_empty, u₁, u₂, threads
108+
inline, check_empty, u₁, u₂, threads, warncheckarg
107109
end
108-
function process_args(args; inline = false, check_empty = false, u₁ = zero(Int8), u₂ = zero(Int8), threads = 1)
110+
function process_args(args; inline = false, check_empty = false, u₁ = zero(Int8), u₂ = zero(Int8), threads = 1, warncheckarg = false)
109111
for arg ∈ args
110-
inline, check_empty, u₁, u₂, threads = check_macro_kwarg(arg, inline, check_empty, u₁, u₂, threads)
112+
inline, check_empty, u₁, u₂, threads, warncheckarg = check_macro_kwarg(arg, inline, check_empty, u₁, u₂, threads, warncheckarg)
111113
end
112-
inline, check_empty, u₁, u₂, threads
114+
inline, check_empty, u₁, u₂, threads, warncheckarg
113115
end
114116
function turbo_macro(mod, src, q, args...)
115117
q = macroexpand(mod, q)
116118

117119
if q.head === :for
118120
ls = LoopSet(q, mod)
119-
inline, check_empty, u₁, u₂, threads = process_args(args)
120-
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, threads))
121+
inline, check_empty, u₁, u₂, threads, warncheckarg = process_args(args)
122+
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, threads, warncheckarg))
121123
else
122-
inline, check_empty, u₁, u₂, threads = process_args(args, inline=true)
123-
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, threads)
124+
inline, check_empty, u₁, u₂, threads, warncheckarg = process_args(args, inline=true)
125+
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, threads, warncheckarg)
124126
end
125127
end
126128
"""
@@ -209,6 +211,10 @@ and `@fastmath` is generated. Note that `VectorizationBase` provides functions s
209211
ignore `@fastmath`, preserving IEEE semantics both within `@turbo` and `@fastmath`.
210212
`check_args` currently returns false for some wrapper types like `LinearAlgebra.UpperTriangular`, requiring you to
211213
use their `parent`. Triangular loops aren't yet supported.
214+
215+
Setting the keyword argument `warn_check_args=true` (e.g. `@turbo warn_check_args=true for ...`) in a loop or
216+
broadcast statement will cause it to warn once if `LoopVectorization.check_args` fails and the fallback
217+
loop is executed instead of the LoopVectorization-optimized loop.
212218
"""
213219
macro turbo(args...)
214220
turbo_macro(__module__, __source__, last(args), Base.front(args)...)
@@ -250,7 +256,7 @@ end
250256
macro _turbo(arg, q)
251257
@assert q.head === :for
252258
q = macroexpand(__module__, q)
253-
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), 1)
259+
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), 1, false)
254260
ls = LoopSet(q, __module__)
255261
set_hw!(ls)
256262
def_outer_reduct_types!(ls)

src/modeling/graphs.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -989,12 +989,13 @@ end
989989

990990

991991
function maybe_const_compute!(ls::LoopSet, LHS::Symbol, op::Operation, elementbytes::Int, position::Int)
992-
# return op
993-
if iscompute(op) && iszero(length(loopdependencies(op)))
994-
ls.opdict[LHS] = add_constant!(ls, LHS, ls.loopsymbols[1:position], gensym!(ls, instruction(op).instr), elementbytes, :numericconstant)
995-
else
996-
op
997-
end
992+
# return op
993+
if iscompute(op) && iszero(length(loopdependencies(op)))
994+
ls.opdict[LHS] = add_constant!(ls, LHS, ls.loopsymbols[1:position], gensym!(ls, instruction(op).instr), elementbytes, :numericconstant)
995+
else
996+
# op.dependencies = ls.loopsymbols[1:position]
997+
op
998+
end
998999
end
9991000
strip_op_linenumber_nodes(q::Expr) = only(filter(x -> !isa(x, LineNumberNode), q.args))
10001001

src/parse/add_compute.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,12 @@ function add_reduction_update_parent!(
183183
parent::Operation, instr::Instruction, reduction_ind::Int, elementbytes::Int
184184
)
185185
var = name(parent)
186+
# isouterreduction = iszero(length(loopdependencies(parent))) && (parent.instruction === LOOPCONSTANT)
186187
isouterreduction = parent.instruction === LOOPCONSTANT
187188
# @show instr, vparents, parent, reduction_ind
188189
# if parent is not an outer reduction...
189190
# if !isouterreduction && !isreductzero(parent, ls, reduct_zero)
191+
# @show isouterreduction, parent, length(loopdependencies(parent))
190192
add_reduct_instruct = !isouterreduction && !isconstant(parent)
191193
if add_reduct_instruct
192194
if instr.instr === :ifelse
@@ -226,7 +228,7 @@ function add_reduction_update_parent!(
226228
update_reduction_status!(vparents, reduceddeps, name(reductinit))
227229
# this is the op added by add_compute
228230
op = Operation(length(operations(ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, vparents)
229-
parent.instruction === LOOPCONSTANT && push!(ls.outer_reductions, identifier(op))
231+
isouterreduction && push!(ls.outer_reductions, identifier(op))
230232
opout = pushop!(ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
231233
# isouterreduction || iszero(length(reduceddeps)) && return opout
232234
# return opout

src/parse/add_constants.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
# assignedsym will be assigned to value within the preamble
6363
function add_constant!(
6464
ls::LoopSet, value::Symbol, deps::Vector{Symbol}, assignedsym::Symbol, elementbytes::Int, f::Symbol = Symbol("")
65-
)
65+
)
6666
retop = get(ls.opdict, value, nothing)
6767
if retop === nothing
6868
op = Operation(length(operations(ls)), assignedsym, elementbytes, Instruction(f, value), constant, deps, NODEPENDENCY, NOPARENTS)

src/parse/add_ifelse.jl

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,49 @@
55
negateop!(ls::LoopSet, condop::Operation, elementbytes::Int) = add_compute!(ls, gensym!(ls, "negated#mask"), :~, [condop], elementbytes)
66

77
function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, position::Int, mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing)
8-
# for now, just simple 1-liners
9-
@assert length(RHS.args) == 3 "if statements without an else cannot be assigned to a variable."
10-
condition = first(RHS.args)
11-
condop = if condition isa Symbol
12-
getop(ls, condition, elementbytes)
13-
elseif mpref === nothing
14-
add_operation!(ls, gensym!(ls, "mask"), condition, elementbytes, position)
15-
else
16-
add_operation!(ls, gensym!(ls, "mask"), condition, mpref, elementbytes, position)
8+
# for now, just simple 1-liners
9+
@assert length(RHS.args) == 3 "if statements without an else cannot be assigned to a variable."
10+
condition = first(RHS.args)
11+
condop = if condition isa Symbol
12+
getop(ls, condition, elementbytes)
13+
elseif mpref === nothing
14+
add_operation!(ls, gensym!(ls, "mask"), condition, elementbytes, position)
15+
else
16+
add_operation!(ls, gensym!(ls, "mask"), condition, mpref, elementbytes, position)
17+
end
18+
iftrue = RHS.args[2]
19+
if iftrue isa Expr
20+
trueop = add_operation!(ls, gensym!(ls, "iftrue"), iftrue, elementbytes, position)
21+
if iftrue.head === :ref && all(ld -> ld ∈ loopdependencies(trueop), loopdependencies(condop)) && !search_tree(parents(condop), trueop)
22+
trueop.instruction = Instruction(:conditionalload)
23+
push!(parents(trueop), condop)
1724
end
18-
iftrue = RHS.args[2]
19-
if iftrue isa Expr
20-
trueop = add_operation!(ls, gensym!(ls, "iftrue"), iftrue, elementbytes, position)
21-
if iftrue.head === :ref && all(ld -> ld ∈ loopdependencies(trueop), loopdependencies(condop)) && !search_tree(parents(condop), trueop)
22-
trueop.instruction = Instruction(:conditionalload)
23-
push!(parents(trueop), condop)
24-
end
25-
else
26-
trueop = getop(ls, iftrue, elementbytes)
25+
else
26+
trueop = getop(ls, iftrue, elementbytes)
27+
end
28+
iffalse = RHS.args[3]
29+
if trueop.instruction === Instruction(:conditionalload)
30+
if ((iffalse isa Number) && (iffalse == 0)) || (Meta.isexpr(iffalse, :call, 2) && (iffalse.args[1] === :zero))
31+
trueop.variable = LHS
32+
ls.opdict[LHS] = trueop
33+
return trueop
2734
end
28-
iffalse = RHS.args[3]
29-
if iffalse isa Expr
30-
falseop = add_operation!(ls, gensym!(ls, "iffalse"), iffalse, elementbytes, position)
31-
if iffalse.head === :ref && all(ld -> ld ∈ loopdependencies(falseop), loopdependencies(condop)) && !search_tree(parents(condop), falseop)
32-
falseop.instruction = Instruction(:conditionalload)
33-
push!(parents(falseop), negateop!(ls, condop, elementbytes))
34-
end
35-
else
36-
falseop = getop(ls, iffalse, elementbytes)
35+
end
36+
if iffalse isa Expr
37+
falseop = add_operation!(ls, gensym!(ls, "iffalse"), iffalse, elementbytes, position)
38+
if iffalse.head === :ref && all(ld -> ld ∈ loopdependencies(falseop), loopdependencies(condop)) && !search_tree(parents(condop), falseop)
39+
falseop.instruction = Instruction(:conditionalload)
40+
push!(parents(falseop), negateop!(ls, condop, elementbytes))
41+
if (any(==(identifier(trueop)), Iterators.map(first, ls.preamble_zeros)))
42+
falseop.variable = LHS
43+
ls.opdict[LHS] = falseop
44+
return falseop
45+
end
3746
end
38-
add_compute_ifelse!(ls, LHS, condop, trueop, falseop, elementbytes)
47+
else
48+
falseop = getop(ls, iffalse, elementbytes)
49+
end
50+
add_compute_ifelse!(ls, LHS, condop, trueop, falseop, elementbytes)
3951
end
4052

4153
function add_andblock!(ls::LoopSet, condop::Operation, LHS, rhsop::Operation, elementbytes::Int, position::Int)

test/fallback.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
end
1212
function msdavx(x)
1313
s = zero(eltype(x))
14-
@turbo for i in eachindex(x)
14+
@turbo warn_check_args=true for i in eachindex(x)
1515
s = muladd(x[i], x[i], s) # Avoids fastmath in fallback loop.
1616
end
1717
s
@@ -33,6 +33,8 @@
3333
@test @inferred !LoopVectorization.check_args(['a'])
3434
@test @inferred !LoopVectorization.check_args(Diagonal(x))
3535

36+
@test_nowarn msdavx(x)
37+
@test_logs (:warn,"`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.") msdavx(FallbackArrayWrapper(x))
3638
@test msdavx(FallbackArrayWrapper(x)) == 1e18
3739
@test msd(x) == msdavx(FallbackArrayWrapper(x))
3840
@test msdavx(x) != msdavx(FallbackArrayWrapper(x))

0 commit comments

Comments
 (0)