Skip to content

Commit 09ff723

Browse files
committed
cse stores; assumes stores can be cse-ed
1 parent f560133 commit 09ff723

File tree

4 files changed

+160
-58
lines changed

4 files changed

+160
-58
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/graphs.jl

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ struct LoopOrder <: AbstractArray{Vector{Operation},5}
5454
bestorder::Vector{Symbol}
5555
end
5656
function LoopOrder(N::Int)
57-
LoopOrder( [ Operation[] for i ∈ 1:24N ], Vector{Symbol}(undef, N), Vector{Symbol}(undef, N) )
57+
LoopOrder(
58+
[ Operation[] for _ ∈ 1:32N ],
59+
Vector{Symbol}(undef, N), Vector{Symbol}(undef, N)
60+
)
5861
end
5962
LoopOrder() = LoopOrder(Vector{Operation}[],Symbol[],Symbol[])
6063
Base.empty!(lo::LoopOrder) = foreach(empty!, lo.oporder)
@@ -87,6 +90,8 @@ struct LoopSet
8790
refs_aliasing_syms::Vector{ArrayReference}
8891
cost_vec::Matrix{Float64}
8992
reg_pres::Matrix{Int}
93+
included_vars::Vector{Bool}
94+
place_after_loop::Vector{Bool}
9095
# sym_to_ref_aliases::Dict{Symbol,ArrayReference}
9196
# ref_to_sym_aliases::Dict{ArrayReference,Symbol}
9297
end
@@ -139,7 +144,8 @@ function LoopSet()
139144
Symbol[],
140145
ArrayReference[],
141146
Matrix{Float64}(undef, 4, 2),
142-
Matrix{Int}(undef, 4, 2)
147+
Matrix{Int}(undef, 4, 2),
148+
Bool[], Bool[]
143149
)
144150
end
145151
num_loops(ls::LoopSet) = length(ls.loops)
@@ -246,16 +252,19 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
246252
lower = r.args[2]
247253
upper = r.args[3]
248254
lii::Bool = lower isa Integer
255+
@assert lii
256+
liiv::Int = convert(Int, lii)
257+
@assert liiv == 1 "Currently only loops starting from the first index are supported."
249258
uii::Bool = upper isa Integer
250259
if lii & uii
251-
Loop(itersym, 1 + convert(Int,upper) - convert(Int,lower))
260+
Loop(itersym, 1 + convert(Int,upper) - liiv)
252261
else
253262
N = gensym(Symbol(:loop, itersym))
254263
ex = if lii
255264
if lower == 1
256265
pushpreamble!(ls, Expr(:(=), N, upper))
257266
else
258-
pushpreamble!(ls, Expr(:(=), N, Expr(:call, :-, upper, lower - 1)))
267+
pushpreamble!(ls, Expr(:(=), N, Expr(:call, :-, upper, liiv - 1)))
259268
end
260269
else
261270
ex = if uii
@@ -437,6 +446,7 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
437446
@view(expr.args[2+offset:end]),
438447
Ref(false)
439448
)::ArrayReference
449+
# whether this finds load or store, we use that
440450
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
441451
if id === nothing
442452
add_load!( ls, gensym(:temporary), ref, elementbytes )
@@ -470,6 +480,7 @@ function add_reduction_update_parent!(
470480
)
471481
parent = getop(ls, var, elementbytes)
472482
setdiffv!(reduceddeps, deps, loopdependencies(parent))
483+
mergesetv!(reduceddependencies(parent), reduceddeps)
473484
pushparent!(parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
474485
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
475486
parent.instruction === LOOPCONSTANT && push!(ls.outer_reductions, identifier(op))
@@ -502,6 +513,19 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
502513
pushop!(ls, op, var)
503514
end
504515
end
516+
function add_unique_store!(ls::LoopSet, ref::ArrayReference, parent::Operation, elementbytes::Int = 8)
517+
ldref = loopdependencies(ref, ls)
518+
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, ldref, reduceddependencies(parent), [parent], ref )
519+
add_vptr!(ls, ref.array, identifier(op), ref.ptr)
520+
pushop!(ls, op, ref.array)
521+
end
522+
function cse_store!(ls::LoopSet, id::Int, ref::ArrayReference, parent::Operation, elementbytes::Int = 8)
523+
ldref = loopdependencies(ref, ls)
524+
op = Operation( length(operations(ls))-1, ref.array, elementbytes, :setindex!, memstore, ldref, reduceddependencies(parent), [parent], ref )
525+
ls.operations[id] = op
526+
ls.opdict[op.variable] = op
527+
op
528+
end
505529
function add_store!(
506530
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
507531
)
@@ -511,10 +535,16 @@ function add_store!(
511535
if pvar ∉ ls.syms_aliasing_refs
512536
push!(ls.syms_aliasing_refs, pvar)
513537
push!(ls.refs_aliasing_syms, ref)
538+
add_unique_store!(ls, ref, parent, elementbytes)
539+
else
540+
# try to cse store
541+
# different from cse load, because the other op here must be a store
542+
for opp ∈ operations(ls)
543+
isstore(opp) || continue
544+
ref == opp.ref && return cse_store!(ls, identifier(opp), ref, parent, elementbytes)
545+
end
546+
add_unique_store!(ls, ref, parent, elementbytes)
514547
end
515-
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, ldref, reduceddependencies(parent), [parent], ref )
516-
add_vptr!(ls, ref.array, identifier(op), ref.ptr)
517-
pushop!(ls, op, ref.array)
518548
end
519549
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
520550
ref = ref_from_ref(ex)::ArrayReference
@@ -626,8 +656,10 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
626656
end
627657
ops = operations(ls)
628658
nops = length(ops)
629-
included_vars = fill(false, nops)
630-
place_after_loop = fill(true, nops)
659+
included_vars = resize!(ls.included_vars, nops)
660+
fill!(included_vars, false)
661+
place_after_loop = resize!(ls.place_after_loop, nops)
662+
fill!(ls.place_after_loop, true)
631663
# to go inside out, we just have to include all those not-yet included depending on the current sym
632664
empty!(lo)
633665
for _n ∈ 1:nloops
@@ -642,13 +674,23 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
642674
istiled = (loopistiled ? (tiled ∈ loopdependencies(op)) : false) + 1
643675
optype = Int(op.node_type) + 1
644676
after_loop = place_after_loop[id] + 1
645-
push!(lo[optype,isunrolled,istiled,after_loop,_n], ops[id])
677+
push!(lo[optype,isunrolled,istiled,after_loop,_n], op)
646678
set_upstream_family!(place_after_loop, op, false) # parents that have already been included are not moved, so no need to check included_vars to filter
647679
end
648680
end
649681
end
650682

651-
683+
function define_remaining_ops!(
684+
ls::LoopSet, vectorized::Symbol, W, unrolled, tiled, U::Int
685+
)
686+
ops = operations(ls)
687+
for (id,incl) ∈ enumerate(ls.included_vars)
688+
if !incl
689+
op = ops[id]
690+
length(reduceddependencies(op)) == 0 && lower!( ls.preamble, op, vectorized, W, unrolled, tiled, U, nothing, nothing )
691+
end
692+
end
693+
end
652694
# function depends_on_assigned(op::Operation, assigned::Vector{Bool})
653695
# for p ∈ parents(op)
654696
# p === op && continue # don't fall into recursive loop when we have updates, eg a = a + b

src/lowering.jl

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ function lower_load!(
167167
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, U::Int,
168168
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
169169
)
170-
# @show op.instruction
171-
# @show unrolled, loopdependencies(op)
172170
if vectorized ∈ loopdependencies(op)
173171
lower_load_vectorized!(q, op, vectorized, W, unrolled, U, suffix, mask)
174172
else
@@ -205,7 +203,6 @@ function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, U::Int)
205203
Uh = Uh2 >> 1
206204
reduce_range!(q, toreduct, instr, Uh, Uh2)
207205
Uh == 1 && break
208-
# @show Uh
209206
Uh2 = Uh
210207
iter += 1; iter > 4 && throw("Oops! This seems to be excessive unrolling.")
211208
end
@@ -220,11 +217,10 @@ pvariable_name(op::Operation, suffix) = Symbol(pvariable_name(op, nothing), suff
220217
function reduce_unroll!(q, op, U, unrolled)
221218
loopdeps = loopdependencies(op)
222219
isunrolled = unrolled ∈ loopdeps
223-
if unrolled ∉ reduceddependencies(op)
220+
if (unrolled ∉ reduceddependencies(op))
224221
U = isunrolled ? U : 1
225222
return U, isunrolled
226223
end
227-
unrolled ∈ reduceddependencies(op) || return U
228224
var = mangledvar(op)
229225
instr = first(parents(op)).instruction
230226
reduce_expr!(q, var, instr, U) # assigns reduction to storevar
@@ -278,7 +274,6 @@ function lower_store_vectorized!(
278274
for u ∈ 0:U-1
279275
name, mo = name_mo(var, op, u, W, vecnotunrolled, unrolled)
280276
instrcall = Expr(:call,lv(:vstore!), ptr, name, mo)
281-
# @show mask, vecnotunrolled, u, U
282277
if mask !== nothing && (vecnotunrolled || u == U - 1)
283278
push!(instrcall.args, mask)
284279
end
@@ -365,7 +360,6 @@ function lower_compute!(
365360
# cache unroll and tiling check of parents
366361
# not broadcasted, because we use frequent checks of individual bools
367362
# making BitArrays inefficient.
368-
# @show instr parentsunrolled
369363
# parentsyms = [opp.variable for opp ∈ parents(op)]
370364
Uiter = opunrolled ? U - 1 : 0
371365
maskreduct = mask !== nothing && isreduction(op) && any(opp -> opp.variable === var, parents_op)
@@ -401,27 +395,13 @@ function lower_compute!(
401395
end
402396
push!(instrcall.args, parent)
403397
end
404-
if maskreduct && u == Uiter # only mask last
398+
if maskreduct && (u == Uiter || unrolled !== vectorized) # only mask last
405399
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:vifelse), mask, instrcall, varsym)))
406400
else
407401
push!(q.args, Expr(:(=), varsym, instrcall))
408402
end
409403
end
410404
end
411-
function lower!(
412-
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
413-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
414-
)
415-
if isload(op)
416-
lower_load!(q, op, vectorized, W, unrolled, U, suffix, mask)
417-
elseif isstore(op)
418-
lower_store!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
419-
elseif iscompute(op)
420-
lower_compute!(q, op, vectorized, W, unrolled, U, suffix, mask)
421-
else
422-
lower_constant!(q, op, vectorized, W, unrolled, U, suffix, mask)
423-
end
424-
end
425405
function lower_constant!(
426406
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, U::Int,
427407
suffix::Union{Nothing,Int}, mask::Any = nothing
@@ -470,9 +450,23 @@ function lower_constant!(
470450
q::Expr, ops::AbstractVector{Operation}, vectorized::Symbol, W::Symbol, unrolled::Symbol, U::Int,
471451
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
472452
)
473-
foreach(op -> lower_constan!(q, op, vectorized, W, unrolled, U, suffix, mask), ops)
453+
foreach(op -> lower_constant!(q, op, vectorized, W, unrolled, U, suffix, mask), ops)
474454
end
475455

456+
function lower!(
457+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
458+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
459+
)
460+
if isconstant(op)
461+
lower_constant!(q, op, vectorized, W, unrolled, U, suffix, mask)
462+
elseif isload(op)
463+
lower_load!(q, op, vectorized, W, unrolled, U, suffix, mask)
464+
elseif iscompute(op)
465+
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask)
466+
else#if isstore(op)
467+
lower_store!(q, op, vectorized, W, unrolled, U, suffix, mask)
468+
end
469+
end
476470
function lower!(
477471
q::Expr, ops::AbstractVector{<:AbstractVector{Operation}}, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
478472
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
@@ -501,7 +495,6 @@ function lower_nest(
501495
nisvectorized = loopsym === vectorized
502496
nisunrolled = false
503497
nistiled = false
504-
# @show n, mask
505498
if istiled
506499
if n == nloops
507500
loopsym = tiledsym(loopsym)
@@ -514,7 +507,6 @@ function lower_nest(
514507
unrolled = last(order)
515508
nisunrolled = n == nloops
516509
end
517-
# @show unrolled, order
518510
blockq = Expr(:block)
519511
n == 1 || push!(blockq.args, Expr(:(=), order[n-1], loopstart))
520512
loopq = if exprtype === :block
@@ -581,7 +573,6 @@ function add_vec_rem_iter(
581573
loopq
582574
end
583575
function lower_set(ls::LoopSet, vectorized::Symbol, U::Int, T::Int, W::Symbol, ::Nothing, Uexprtype::Symbol)
584-
# @show U, T, W
585576
loopstart = 0
586577
istiled = T != -1
587578
order = names(ls)
@@ -620,11 +611,11 @@ function lower_set_unrolled_is_vectorized(ls::LoopSet, vectorized::Symbol, U::In
620611
loopq
621612
end
622613
function initialize_outer_reductions!(
623-
q::Expr, op::Operation, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, unrolled::Symbol, suffix::Union{Symbol,Nothing} = nothing
614+
q::Expr, op::Operation, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, vectorized::Symbol, suffix::Union{Symbol,Nothing} = nothing
624615
)
625616
# T = op.elementbytes == 8 ? :Float64 : :Float32
626617
z = Expr(:call, REDUCTION_ZERO[op.instruction], typeT)
627-
if unrolled ∈ reduceddependencies(op)
618+
if vectorized ∈ reduceddependencies(op)
628619
z = Expr(:call, lv(:vbroadcast), W, z)
629620
end
630621
mvar = variable_name(op, suffix)
@@ -633,16 +624,15 @@ function initialize_outer_reductions!(
633624
end
634625
nothing
635626
end
636-
function initialize_outer_reductions!(q::Expr, ls::LoopSet, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, unrolled::Symbol, suffix::Union{Symbol,Nothing} = nothing)
637-
foreach(or -> initialize_outer_reductions!(q, ls.operations[or], Umin, Umax, W, typeT, unrolled, suffix), ls.outer_reductions)
627+
function initialize_outer_reductions!(q::Expr, ls::LoopSet, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, vectorized::Symbol, suffix::Union{Symbol,Nothing} = nothing)
628+
foreach(or -> initialize_outer_reductions!(q, ls.operations[or], Umin, Umax, W, typeT, vectorized, suffix), ls.outer_reductions)
638629
end
639-
function initialize_outer_reductions!(ls::LoopSet, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, unrolled::Symbol, suffix::Union{Symbol,Nothing} = nothing)
640-
initialize_outer_reductions!(ls.preamble, ls, Umin, Umax, W, typeT, unrolled, suffix)
630+
function initialize_outer_reductions!(ls::LoopSet, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, vectorized::Symbol, suffix::Union{Symbol,Nothing} = nothing)
631+
initialize_outer_reductions!(ls.preamble, ls, Umin, Umax, W, typeT, vectorized, suffix)
641632
end
642-
function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::Int, W::Symbol, typeT::Symbol, unrolledloop::Loop)
643-
unrolled = unrolledloop.itersymbol
633+
function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::Int, W::Symbol, typeT::Symbol, unrolledloop::Loop, vectorized::Symbol)
644634
ifq = Expr(:block)
645-
initialize_outer_reductions!(ifq, ls, Ulow, Uhigh, W, typeT, unrolled)
635+
initialize_outer_reductions!(ifq, ls, Ulow, Uhigh, W, typeT, vectorized)
646636
push!(ifq.args, loopq)
647637
reduce_range!(ifq, ls, Ulow, Uhigh)
648638
comparison = if unrolledloop.hintexact
@@ -786,7 +776,7 @@ function lower_unrolled_dynamic!(
786776
if manageouterreductions
787777
# Umax = (!static_unroll && U > 2) ? U >> 1 : U
788778
Ureduct = U > 6 ? 4 : U
789-
initialize_outer_reductions!(q, ls, 0, Ureduct, W, typeT, last(names(ls)))
779+
initialize_outer_reductions!(q, ls, 0, Ureduct, W, typeT, vectorized)#last(names(ls)))
790780
else
791781
Ureduct = -1
792782
end
@@ -798,7 +788,7 @@ function lower_unrolled_dynamic!(
798788
if firstiter # first iter
799789
loopq = lower_set(ls, vectorized, Ut, T, W, nothing, Uexprtype)
800790
if T == -1 && manageouterreductions && U > 4
801-
loopq = add_upper_outer_reductions(ls, loopq, Ureduct, U, W, typeT, unrolledloop)
791+
loopq = add_upper_outer_reductions(ls, loopq, Ureduct, U, W, typeT, unrolledloop, vectorized)
802792
end
803793
push!(q.args, loopq)
804794
elseif U == 1 #
@@ -863,10 +853,11 @@ function definemask(loop::Loop, W::Symbol, allon::Bool)
863853
maskexpr(W, loop.rangesym, allon)
864854
end
865855
end
866-
function setup_Wmask!(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, U::Int)
856+
function setup_Wmask!(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, tiled::Symbol, U::Int)
867857
pushpreamble!(ls, Expr(:(=), typeT, determine_eltype(ls)))
868858
pushpreamble!(ls, Expr(:(=), W, determine_width(ls, typeT, unrolled)))
869859
pushpreamble!(ls, definemask(ls.loops[vectorized], W, U > 1 && unrolled === vectorized))
860+
# define_remaining_ops!( ls, vectorized, W, unrolled, tiled, U )
870861
end
871862
function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
872863
order = ls.loop_order.loopnames
@@ -875,11 +866,11 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
875866
mangledtiled = tiledsym(tiled)
876867
W = gensym(:W)
877868
typeT = gensym(:T)
878-
setup_Wmask!(ls, W, typeT, vectorized, unrolled, U)
869+
setup_Wmask!(ls, W, typeT, vectorized, unrolled, tiled, U)
879870
tiledloop = ls.loops[tiled]
880871
static_tile = tiledloop.hintexact
881872
unrolledloop = ls.loops[unrolled]
882-
initialize_outer_reductions!(ls, 0, 4, W, typeT, unrolled)
873+
initialize_outer_reductions!(ls, 0, 4, W, typeT, vectorized)#unrolled)
883874
q = Expr(:block, Expr(:(=), mangledtiled, 0))
884875
# we build up the loop expression.
885876
Trem = Tt = T
@@ -932,12 +923,11 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
932923
end
933924
function lower_unrolled(ls::LoopSet, vectorized::Symbol, U::Int)
934925
order = ls.loop_order.loopnames
935-
# @show order
936926
unrolled = last(order)
937927
# W = VectorizationBase.pick_vector_width(ls, unrolled)
938928
W = gensym(:W)
939929
typeT = gensym(:T)
940-
setup_Wmask!(ls, W, typeT, vectorized, unrolled, U)
930+
setup_Wmask!(ls, W, typeT, vectorized, unrolled, last(order), U)
941931
q = lower_unrolled!(Expr(:block, Expr(:(=), unrolled, 0)), ls, vectorized, U, -1, W, typeT, ls.loops[unrolled])
942932
Expr(:block, ls.preamble, q)
943933
end
@@ -950,11 +940,8 @@ end
950940
# Requires sorting
951941
function lower(ls::LoopSet)
952942
order, vectorized, U, T = choose_order(ls)
953-
# @show order, U, T
954-
# @show ls.loop_order.loopnames
955943
istiled = T != -1
956944
fillorder!(ls, order, istiled)
957-
# @show order, ls.loop_order.loopnames
958945
istiled ? lower_tiled(ls, vectorized, U, T) : lower_unrolled(ls, vectorized, U)
959946
end
960947

0 commit comments

Comments
 (0)