Skip to content

Commit 704e23c

Browse files
committed
Slightly more robust handling of unrolling in inner vectorized reductions
1 parent e043283 commit 704e23c

File tree

12 files changed

+125
-62
lines changed

12 files changed

+125
-62
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ SLEEFPirates = "0.6.14"
2828
Static = "0.2"
2929
ThreadingUtilities = "0.4.1"
3030
UnPack = "1"
31-
VectorizationBase = "0.19.22"
31+
VectorizationBase = "0.19.28"
3232
julia = "1.5"
3333

3434
[extras]

src/LoopVectorization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ using Base.Meta: isexpr
3434
using DocStringExtensions
3535
import LinearAlgebra # for check_args
3636

37-
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast, rem_fast, max_fast, min_fast, log_fast, log2_fast, log10_fast
37+
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast, rem_fast, max_fast, min_fast
38+
using SLEEFPirates: log_fast, log2_fast, log10_fast
3839

3940

4041
using ArrayInterface

src/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ function add_broadcast!(
368368
mergesetdiffv!(deps, loopdependencies(parent), reduceddependencies(parent))
369369
end
370370
op = Operation(
371-
length(operations(ls)), destname, elementbytes, instr, compute, deps, NOPARENTS, parents
371+
length(operations(ls)), destname, elementbytes, instr, compute, deps, NODEPENDENCY, parents
372372
)
373373
pushop!(ls, op, destname)
374374
end

src/codegen/lower_compute.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ function parent_op_name(
274274
opp = parents_op[n]
275275
parent = mangledvar(opp)
276276
u = 0
277-
if n == tiledouterreduction
277+
if n == tiledouterreduction# && isvectorized(opp)
278278
parent = Symbol(parent, modsuffix)
279279
else
280280
if parents_u₂syms[n]
@@ -452,7 +452,11 @@ function lower_compute!(
452452
# modsuffix = ls.unrollspecification[].u₁#getu₁full(ls, u₁)#u₁
453453
# Symbol(mangledvar(op), '_', modsuffix)
454454
# else
455-
modsuffix = 0#suffix % tiled_outerreduct_unroll(ls)
455+
if u₁unrolledsym
456+
modsuffix = 0
457+
else
458+
modsuffix = suffix % ls.ureduct[]
459+
end
456460
Symbol(mangledvar(op), modsuffix)
457461
# end
458462
# dopartialmap = u₁ > 1
@@ -483,7 +487,8 @@ function lower_compute!(
483487
add_loopvalue!(instrcall, loopval, ua, u₁)
484488
elseif name(opp) === name(op)
485489
selfdep = n
486-
if ((isvectorized(opp) && !isvectorized(op)) && !dependent_outer_reducts(ls, op)) ||
490+
# @show mangledvar(op), name(opp), name(op)
491+
if ((isvectorized(opp) && !isvectorized(op))) ||
487492
(parents_u₁syms[n] != u₁unrolledsym) || (parents_u₂syms[n] != u₂unrolledsym)
488493

489494
selfopname, uₚ = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
@@ -495,14 +500,26 @@ function lower_compute!(
495500
else
496501
push!(instrcall.args, varsym)
497502
end
498-
elseif ((!isu₂unrolled(op)) & isu₂unrolled(opp)) && (isouterreduction(ls, opp) != -1)
503+
elseif ((!isu₂unrolled(op)) & isu₂unrolled(opp)) && (parents_u₂syms[n] & (!u₂unrolledsym))
504+
# elseif parents_u₂syms[n] & (!u₂unrolledsym)
505+
#&& (isouterreduction(ls, opp) != -1)
499506
# this checks if the parent is u₂ unrolled but this operation is not, in which case we need to reduce it.
507+
# @show op opp
500508
reduced_u₂ = reduce_expr_u₂(mangledvar(opp), instruction(opp), ureduct(ls))
501-
reduced_u₂ = reduce_parent!(q, ls, op, opp, reduced_u₂)
509+
reducedparentname = gensym!(ls, "reducedop")
510+
push!(q.args, Expr(:(=), reducedparentname, reduced_u₂))
511+
reduced_u₂ = reduce_parent!(q, ls, op, opp, reducedparentname)
502512
push!(instrcall.args, reduced_u₂)
503513
else
504514
parent, uₚ = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
505515
parent = reduce_parent!(q, ls, op, opp, parent)
516+
# if instr.instr === :vfmadd_fast && tiledouterreduction > 0
517+
# @show mvar, varsym, selfopname
518+
# end
519+
# @show opp
520+
# if instr.instr === :identity
521+
# @show isvectorized(op) isvectorized(opp)
522+
# end
506523
if (selfdep == 0) && search_tree(parents(opp), name(op))
507524
selfdep = n
508525
push!(instrcall.args, parent)

src/codegen/lower_store.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ function lower_tiled_store!(
181181
end
182182

183183
function donot_tile_store(ls::LoopSet, op::Operation, vloop::Loop, reductfunc::Symbol, u₂::Int)
184-
(!((reductfunc === Symbol("")) && all(op.ref.loopedindex))) || (u₂ 1) || isconditionalmemop(op) && return true
184+
((!((reductfunc === Symbol("")) && all(op.ref.loopedindex))) || (u₂ 1) || isconditionalmemop(op)) && return true
185185
rejectcurly(op) && return true
186186
omop = offsetloadcollection(ls)
187187
batchid, opind = omop.batchedcollectionmap[identifier(op)]

src/codegen/lowering.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,9 @@ function initialize_outer_reductions!(
441441
@unpack u₁, u₂ = us
442442
Umax = u₂ == -1 ? _Umax : u₁
443443
reduct_zero = reduction_zero(op.instruction)
444-
isvectorized = vectorized reduceddependencies(op)
445444
typeTr = ELTYPESYMBOL
446445
u₁u, u₂u = isunrolled_sym(op, getloop(ls, us.u₁loopnum).itersymbol, getloop(ls, us.u₂loopnum).itersymbol, getloop(ls, us.vloopnum).itersymbol)#, u₂)
447-
z = if isvectorized
446+
z = if isvectorized(op)
448447
if Umax == 1 || !u₁u
449448
if reduct_zero === :zero
450449
Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, typeTr, rs)
@@ -792,25 +791,15 @@ function calc_Ureduct!(ls::LoopSet, us::UnrollSpecification)
792791
if u₁ui == -1
793792
u₁ui = Int(u₁u)
794793
u₂ui = Int(u₁u)
795-
else
796-
@assert (u₁ui == Int(u₁u)) & (u₂ui == Int(u₁u)) "Doesn't currenly andle differently unrolled reductions yet, please file an issue with an example."
794+
elseif !((u₁ui == Int(u₁u)) & (u₂ui == Int(u₁u)))
795+
throw(ArgumentError("Doesn't currenly handle differently unrolled reductions yet, please file an issue with an example."))
797796
end
798797
end
799798
if u₁ui % Bool
800799
u₁
801-
# push!(q.args, Expr(:(=), Symbol(mvar, '_', u₁), z))
802800
else
803801
u₂
804-
# for u ∈ 0:_Umax-1
805-
# # push!(q.args, Expr(:(=), Symbol(mvar, '_', u), z))
806-
# push!(q.args, Expr(:(=), Symbol(mvar, u), z))
807-
# end
808802
end
809-
# u₁loopnum == vloopnum
810-
# u₂
811-
# else
812-
# u₁
813-
# u₁#tiled_outerreduct_unroll(us)
814803
end
815804
ls.ureduct[] = ur
816805
end
@@ -935,6 +924,8 @@ function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop
935924
((u₂max > 1) | accesses_memory(op)) ? isunrolled_sym(op, u₁loop, u₂loop, vloop) : (isunrolled_sym(op, u₁loop), false)
936925
end
937926

927+
928+
938929
function variable_name(op::Operation, suffix::Int)
939930
mvar = mangledvar(op)
940931
suffix == -1 ? mvar : Symbol(mvar, suffix, :_)

src/codegen/operation_evaluation_order.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,12 @@ function isnopidentity(ls::LoopSet, op::Operation, u₁loop::Symbol, u₂loop::S
3434
# if (u₁unrolledsym == first(parents_u₁syms)) && (isu₂unrolled(op) == parents_u₂syms[1])
3535
opp = only(parents_op)
3636
if (isu₁unrolled(op) == isu₁unrolled(opp)) & (isu₂unrolled(op) == isu₂unrolled(opp))
37-
#TODO: identifer(first(parents_op)) ∉ ls.outer_reductions is going to miss a lot of cases
38-
#Should probably replace that with `DVec` (demoting Vec) types, that demote to scalar.
39-
#TODO: document (after finding out...) why only checking `isvectorized(first(parents_op))` -- why not `any(isvectorized, parents_op)`???
40-
if (isvectorized(opp) && !isvectorized(op)) && !dependent_outer_reducts(ls, op)
37+
true
38+
else
39+
if isvectorized(opp) & (!isvectorized(op))
4140
op.instruction = reduction_to_scalar(instruction(opp))
4241
op.mangledvariable = gensym(op.mangledvariable)
43-
false
44-
else
45-
true
4642
end
47-
else
4843
false
4944
end
5045
else

src/modeling/graphs.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ function cacheunrolled!(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, vloopsy
592592
end
593593
end
594594
for op operations(ls)
595-
setunrolled!(op, u₁loop, u₂loop, vloopsym)
595+
setunrolled!(ls, op, u₁loop, u₂loop, vloopsym)
596596
if accesses_memory(op)
597597
rc = rejectcurly(ls, op, u₁loop, vloopsym)
598598
op.rejectcurly = rc
@@ -606,6 +606,47 @@ function cacheunrolled!(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, vloopsy
606606
end
607607
end
608608
end
609+
function setunrolled!(ls::LoopSet, op::Operation, u₁loopsym, u₂loopsym, vectorized)
610+
op.u₁unrolled = u₁loopsym loopdependencies(op)
611+
op.u₂unrolled = u₂loopsym loopdependencies(op)
612+
op.vectorized = vectorized loopdependencies(op)
613+
if isconstant(op)
614+
u₁ = op.u₁unrolled
615+
u₂ = op.u₂unrolled
616+
v = op.vectorized
617+
for opp children(op)
618+
u₁ = u₁ && u₁loopsym loopdependencies(opp)
619+
u₂ = u₂ && u₂loopsym loopdependencies(opp)
620+
v = v && vectorized loopdependencies(opp)
621+
end
622+
if isouterreduction(ls, op) -1 && !all((u₁,u₂,v))
623+
opv = true
624+
for opp parents(op)
625+
if iscompute(opp) && instruction(opp).instr :identity
626+
opv = false
627+
break
628+
end
629+
end
630+
if opv
631+
if !u₁ && u₁loopsym reduceddependencies(op)
632+
u₁ = true
633+
end
634+
if !u₂ && u₂loopsym reduceddependencies(op)
635+
u₂ = true
636+
end
637+
if !v && vectorized reduceddependencies(op)
638+
v = true
639+
end
640+
end
641+
end
642+
op.u₁unrolled = u₁
643+
op.u₂unrolled = u₂
644+
op.vectorized = v
645+
end
646+
# op.u₁unrolled, op.u₂unrolled = isunrolled_sym(op, u₁loopsym, u₂loopsym, vectorized)
647+
nothing
648+
end
649+
609650
rejectcurly(op::Operation) = op.rejectcurly
610651
rejectinterleave(op::Operation) = op.rejectinterleave
611652
num_loops(ls::LoopSet) = length(ls.loops)

src/modeling/operations.jl

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -200,21 +200,19 @@ mutable struct Operation <: AbstractLoopOperation
200200
rejectinterleave::Bool
201201
function Operation(
202202
identifier::Int,
203-
variable,
204-
elementbytes,
205-
instruction,
206-
node_type,
207-
dependencies = Symbol[],
208-
reduced_deps = Symbol[],
209-
parents = Operation[],
203+
variable::Symbol,
204+
elementbytes::Int,
205+
instruction::Union{Symbol,Instruction},
206+
node_type::OperationType,
207+
dependencies::Vector{Symbol} = Symbol[],
208+
reduced_deps::Vector{Symbol} = Symbol[],
209+
parents::Vector{Operation} = Operation[],
210210
ref::ArrayReferenceMeta = NOTAREFERENCE,
211-
reduced_children = Symbol[]
211+
reduced_children::Vector{Symbol} = Symbol[]
212212
)
213213
new(
214214
identifier, variable, elementbytes, instruction, node_type,
215-
convert(Vector{Symbol},dependencies),
216-
convert(Vector{Symbol},reduced_deps),
217-
convert(Vector{Operation},parents), Operation[],
215+
dependencies, reduced_deps, parents, Operation[],
218216
ref, Symbol("##", variable, :_),
219217
reduced_children
220218
)
@@ -224,19 +222,6 @@ end
224222
isu₁unrolled(op::Operation) = op.u₁unrolled
225223
isu₂unrolled(op::Operation) = op.u₂unrolled
226224
isvectorized(op::Operation) = op.vectorized
227-
function setunrolled!(op::Operation, u₁loopsym, u₂loopsym, vectorized)
228-
op.u₁unrolled = u₁loopsym loopdependencies(op)
229-
op.u₂unrolled = u₂loopsym loopdependencies(op)
230-
op.vectorized = vectorized loopdependencies(op)
231-
if isconstant(op)
232-
for opp children(op)
233-
op.u₁unrolled = op.u₁unrolled && u₁loopsym loopdependencies(opp)
234-
op.u₂unrolled = op.u₂unrolled && u₂loopsym loopdependencies(opp)
235-
op.vectorized = op.vectorized && vectorized loopdependencies(opp)
236-
end
237-
end
238-
nothing
239-
end
240225

241226
function matches(op1::Operation, op2::Operation)
242227
op1 === op2 && return true

src/parse/add_stores.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function add_conditional_store!(ls::LoopSet, LHS, condop::Operation, storeop::Op
123123
# return cse_store!(ls, op)
124124
# end
125125
# end
126-
op = Operation( id, name(mref), elementbytes, :conditionalstore!, memstore, ldref, NODEPENDENCY, storeparents, mref )
126+
op = Operation( id, name(mref), elementbytes, :conditionalstore!, memstore, ldref, reduceddependencies(storeop), storeparents, mref )
127127
add_unique_store!(ls, op)
128128
end
129129

0 commit comments

Comments
 (0)