Skip to content

Commit 876ac51

Browse files
committed
Split ifelse(loopconst,x,y)
1 parent 6f84d62 commit 876ac51

File tree

4 files changed

+113
-54
lines changed

4 files changed

+113
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.14"
4+
version = "0.12.15"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/condense_loopset.jl

Lines changed: 102 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ function getroots!(rooted::Vector{Bool}, ls::LoopSet)
137137
for op ops
138138
isstore(op) && recursively_set_parents_true!(rooted, op)
139139
end
140+
length(ls.includedactualarrays) == 0 || remove_outer_reducts!(rooted, ls)
140141
return rooted
141142
end
142143
function OperationStruct!(varnames::Vector{Symbol}, ids::Vector{Int}, ls::LoopSet, op::Operation)
@@ -461,28 +462,79 @@ function remove_outer_reducts!(roots::Vector{Bool}, ls::LoopSet)
461462
end
462463
end
463464

464-
# function generate_call_split(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool = false)
465-
# ops = operations(ls)
466-
# for op ∈ ops
467-
# if (iscompute(op) && (instruction(op).instr === :ifelse)) && iszero(length(loopdependencies(first(parents(op)))))
468-
# # we want to eliminate
469-
# end
470-
# end
471-
# end
472465

473-
# Try to condense in type stable manner
474-
function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool = false)
475-
extra_args = Expr(:tuple)
476-
preserve, shouldindbyind, roots = add_grouped_strided_pointer!(extra_args, ls)
477466

467+
function split_ifelse!(
468+
ls::LoopSet, preserve::Vector{Symbol}, shouldindbyind::Vector{Bool}, roots::Vector{Bool}, extra_args::Expr, k::Int, inlineu₁u₂::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool
469+
)
470+
roots[k] = false
471+
op = operations(ls)[k]
472+
op.instruction = DROPPEDCONSTANT
473+
op.node_type = constant
474+
# we want to eliminate
475+
parents_op = parents(op)
476+
condop = first(parents_op)
477+
# create one loop where `opp` is true, and a second where it is `false`
478+
prepre = ls.prepreamble; append!(prepre.args, ls.preamble.args)
479+
ls.prepreamble = Expr(:block); ls.preamble = Expr(:block);
480+
ls_true = deepcopy(ls)
481+
lsfalse = ls
482+
true_ops = operations(ls_true)
483+
falseops = operations(lsfalse)
484+
true_op = parents(true_ops[k])[2]
485+
falseop = parents_op[3]
486+
condop_count = 0
487+
for i eachindex(falseops)
488+
fop = falseops[i]
489+
parents_false = parents(fop)
490+
for (j,opp) enumerate(parents_false)
491+
if opp === op # then ops[i]'s jth parent is the ifelse
492+
parents(true_ops[i])[j] = true_op
493+
parents_false[j] = falseop
494+
end
495+
condop_count += roots[i] & (condop === opp)
496+
end
497+
end
498+
roots[identifier(condop)] &= condop_count > 0
499+
q = :(if $(name(condop))
500+
$(generate_call_split(ls_true, preserve, shouldindbyind, roots, copy(extra_args), inlineu₁u₂, thread, debug))
501+
else
502+
$(generate_call_split(lsfalse, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug))
503+
end)
504+
push!(prepre.args, q)
505+
prepre
506+
end
507+
508+
function generate_call(ls::LoopSet, inlineu₁u₂::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool)
509+
extra_args = Expr(:tuple)
510+
preserve, shouldindbyind, roots = add_grouped_strided_pointer!(extra_args, ls)
511+
generate_call_split(ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)
512+
end
513+
function generate_call_split(
514+
ls::LoopSet, preserve::Vector{Symbol}, shouldindbyind::Vector{Bool}, roots::Vector{Bool}, extra_args::Expr, inlineu₁u₂::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool
515+
)
516+
if !debug
517+
for (k,op) enumerate(operations(ls))
518+
parents_op = parents(op)
519+
if (iscompute(op) && (instruction(op).instr === :ifelse)) && (length(parents_op) == 3) && isconstantop(first(parents_op))
520+
return split_ifelse!(ls, preserve, shouldindbyind, roots, extra_args, k, inlineu₁u₂, thread, debug)
521+
end
522+
end
523+
end
524+
return generate_call_types(ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)
525+
end
526+
# Try to condense in type stable manner
527+
function generate_call_types(
528+
ls::LoopSet, preserve::Vector{Symbol}, shouldindbyind::Vector{Bool}, roots::Vector{Bool}, extra_args::Expr, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool
529+
)
530+
# good place to check for split
478531
operation_descriptions = Expr(:tuple)
479532
varnames = Symbol[]; ids = Vector{Int}(undef, length(operations(ls)))
480533
ops = operations(ls)
481-
length(ls.includedactualarrays) == 0 || remove_outer_reducts!(roots, ls)
482534
for op ops
483535
instr::Instruction = instruction(op)
484536
if (isconstant(op) && (instr == LOOPCONSTANT)) && (!roots[identifier(op)])
485-
instr = op.instruction = DROPPEDCONSTANT
537+
instr = op.instruction = DROPPEDCONSTANT
486538
end
487539
push!(operation_descriptions.args, QuoteNode(instr.mod))
488540
push!(operation_descriptions.args, QuoteNode(instr.instr))
@@ -505,7 +557,7 @@ function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, t
505557
configarg = (inline,u₁,u₂,ls.isbroadcast,thread)
506558
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
507559
q = Expr(:call, func, unroll_param_tup, val(operation_descriptions), val(arrayref_descriptions), val(argmeta), val(loop_syms))
508-
560+
509561
add_reassigned_syms!(extra_args, ls) # counterpart to `add_ops!` constants
510562
for (opid,sym) ls.preamble_symsym # counterpart to process_metadata! symsym extraction
511563
if instruction(ops[opid]) DROPPEDCONSTANT
@@ -517,7 +569,13 @@ function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, t
517569
push!(q.args, Expr(:tuple, lbarg, extra_args))
518570
vecwidthdefq = Expr(:block)
519571
define_eltype_vec_width!(vecwidthdefq, ls, nothing)
520-
Expr(:block, vecwidthdefq, q), preserve
572+
push!(vecwidthdefq.args, q)
573+
if debug
574+
pushpreamble!(ls,vecwidthdefq)
575+
Expr(:block, ls.prepreamble, ls.preamble)
576+
else
577+
setup_call_final(ls, setup_outerreduct_preserve(ls, vecwidthdefq, preserve))
578+
end
521579
end
522580

523581

@@ -581,34 +639,32 @@ function gc_preserve(call::Expr, preserve::Vector{Symbol})
581639
q
582640
end
583641

584-
function setup_call_inline(ls::LoopSet, inline::Bool, u₁::Int8, u₂::Int8, thread::Int)
585-
call, preserve = generate_call(ls, (inline,u₁,u₂), thread % UInt, false)
586-
if iszero(length(ls.outer_reductions))
587-
pushpreamble!(ls, gc_preserve(call, preserve))
588-
push!(ls.preamble.args, nothing)
589-
return ls.preamble
590-
end
591-
retv = loopset_return_value(ls, Val(false))
592-
outer_reducts = Expr(:local)
593-
q = Expr(:block,gc_preserve(Expr(:(=), retv, call), preserve))
594-
for or ls.outer_reductions
595-
op = ls.operations[or]
596-
var = name(op)
597-
# push!(call.args, Symbol("##TYPEOF##", var))
598-
mvar = mangledvar(op)
599-
instr = instruction(op)
600-
out = Symbol(mvar, "##onevec##")
601-
push!(outer_reducts.args, out)
602-
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Expr(:call, lv(:vecmemaybe), out), var)))
603-
end
604-
pushpreamble!(ls, outer_reducts)
605-
append!(ls.preamble.args, q.args)
606-
ls.preamble
642+
# function setup_call_inline(ls::LoopSet, inline::Bool, u₁::Int8, u₂::Int8, thread::Int)
643+
# call, preserve = generate_call_split(ls, (inline,u₁,u₂), thread % UInt, false)
644+
# setup_call_ret!(ls, call, preserve)
645+
# end
646+
function setup_outerreduct_preserve(ls::LoopSet, call::Expr, preserve::Vector{Symbol})
647+
iszero(length(ls.outer_reductions)) && return gc_preserve(call, preserve)
648+
retv = loopset_return_value(ls, Val(false))
649+
q = Expr(:block, gc_preserve(Expr(:(=), retv, call), preserve))
650+
for or ls.outer_reductions
651+
op = ls.operations[or]
652+
var = name(op)
653+
# push!(call.args, Symbol("##TYPEOF##", var))
654+
mvar = mangledvar(op)
655+
instr = instruction(op)
656+
out = Symbol(mvar, "##onevec##")
657+
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Expr(:call, lv(:vecmemaybe), out), var)))
658+
end
659+
q
660+
end
661+
function setup_call_final(ls::LoopSet, q::Expr)
662+
pushpreamble!(ls, q)
663+
push!(ls.preamble.args, nothing)
664+
return ls.preamble
607665
end
608666
function setup_call_debug(ls::LoopSet)
609-
# avx_loopset(instr, ops, arf, AM, LB, vargs)
610-
pushpreamble!(ls, first(generate_call(ls, (false,zero(Int8),zero(Int8)), zero(UInt), true)))
611-
Expr(:block, ls.prepreamble, ls.preamble)
667+
generate_call(ls, (false,zero(Int8),zero(Int8)), zero(UInt), true)
612668
end
613669
function setup_call(
614670
ls::LoopSet, q::Expr, source::LineNumberNode, inline::Bool, check_empty::Bool, u₁::Int8, u₂::Int8, thread::Int
@@ -620,9 +676,9 @@ function setup_call(
620676
# inlining the generated function into the loop preamble.
621677
lnns = extract_all_lnns(q)
622678
pushfirst!(lnns, source)
623-
call = setup_call_inline(ls, inline, u₁, u₂, thread)
679+
call = generate_call(ls, (inline, u₁, u₂), thread%UInt, false)
624680
call = check_empty ? check_if_empty(ls, call) : call
625-
result = Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), call, make_crashy(make_fast(q))))
626-
prepend_lnns!(result, lnns)
627-
return result
681+
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, make_crashy(make_fast(q))))
682+
prepend_lnns!(ls.prepreamble, lnns)
683+
return ls.prepreamble
628684
end

src/modeling/graphs.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -594,14 +594,17 @@ This is used so that identical loops will create identical `_avx_!` calls in the
594594
"""
595595
gensym!(ls::LoopSet, s) = Symbol("###$(s)###$(ls.symcounter += 1)###")
596596

597+
function fill_children!(ls::LoopSet)
598+
for op operations(ls)
599+
empty!(children(op))
600+
for opp parents(op)
601+
push!(children(opp), op)
602+
end
603+
end
604+
end
597605
function cacheunrolled!(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, vloopsym::Symbol)
606+
fill_children!(ls)
598607
vloop = getloop(ls, vloopsym)
599-
for op operations(ls)
600-
empty!(children(op))
601-
for opp parents(op)
602-
push!(children(opp), op)
603-
end
604-
end
605608
for op operations(ls)
606609
setunrolled!(ls, op, u₁loop, u₂loop, vloopsym)
607610
if accesses_memory(op)

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function _precompile_()
7373
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 2, 1, 0, (1, 2), Tuple{StaticInt{8}, StaticInt{72}}, Tuple{StaticInt{1}, StaticInt{1}}},Tuple{StaticInt{1}, StaticInt{1}}}) # time: 0.003950459
7474
Base.precompile(Tuple{typeof(add_operation!),LoopSet,Symbol,Expr,Int,Int}) # time: 0.003884223
7575
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 4, 1, 0, (1, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int}, NTuple{4, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{1, Tuple{StaticInt{1}}}, VectorizationBase.CartesianVIndex{3, Tuple{StaticInt{1}, StaticInt{1}, StaticInt{1}}}}}) # time: 0.003860693
76-
Base.precompile(Tuple{typeof(setup_call_inline),LoopSet,Bool,Int8,Int8,Int}) # time: 0.003833778
76+
# Base.precompile(Tuple{typeof(setup_call_inline),LoopSet,Bool,Int8,Int8,Int}) # time: 0.003833778
7777
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 2, 1, 0, (1, 2), Tuple{StaticInt{8}, StaticInt{64}}, Tuple{StaticInt{1}, StaticInt{1}}},Tuple{StaticInt{1}, StaticInt{1}}}) # time: 0.003823736
7878
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 2, 1, 0, (1, 2), Tuple{StaticInt{8}, StaticInt{120}}, Tuple{StaticInt{1}, StaticInt{1}}},Tuple{StaticInt{1}, StaticInt{1}}}) # time: 0.003811805
7979
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 2, 1, 0, (1, 2), Tuple{StaticInt{8}, StaticInt{104}}, Tuple{StaticInt{1}, StaticInt{1}}},Tuple{StaticInt{1}, StaticInt{1}}}) # time: 0.003737383

0 commit comments

Comments
 (0)