Skip to content

Commit 80dad59

Browse files
committed
Switch lowering for twice-unrolled outer-reductions to use s_mod(u1 + U1*u2,4) rather than s_u2).
1 parent b308bb0 commit 80dad59

File tree

3 files changed

+19
-22
lines changed

3 files changed

+19
-22
lines changed

src/lower_compute.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,27 @@ function lower_compute!(
3939
parentsunrolled[i] || continue
4040
parentsunrolled[i] = false
4141
parentop = parents_op[i]
42+
i == tiledouterreduction && isconstant(parentop) && continue
4243
newparentop = Operation(
4344
parentop.identifier, gensym(parentop.variable), parentop.elementbytes, parentop.instruction, parentop.node_type,
4445
parentop.dependencies, parentop.reduced_deps, parentop.parents, parentop.ref, parentop.reduced_children
4546
)
4647
parentname = mangledvar(parentop)
4748
newparentname = mangledvar(newparentop)
4849
parents_op[i] = newparentop
49-
if i == tiledouterreduction && isconstant(newparentop)
50-
push!(q.args, Expr(:(=), Symbol(newparentname, suffix), Symbol(parentname, suffix)))
50+
if parentstiled[i]
51+
parentname = Symbol(parentname, suffix_)
52+
newparentname = Symbol(newparentname, suffix_)
53+
end
54+
if isconstant(newparentop)
55+
# @show i, parentstiled[i], newparentname, parentname
56+
push!(q.args, Expr(:(=), newparentname, Symbol(parentname, 0)))
5157
else
52-
if parentstiled[i]
53-
parentname = Symbol(parentname, suffix_)
54-
newparentname = Symbol(newparentname, suffix_)
55-
end
56-
if isconstant(newparentop)
57-
# @show i, parentstiled[i], newparentname, parentname
58-
push!(q.args, Expr(:(=), newparentname, Symbol(parentname, 0)))
59-
else
60-
for u 0:U-1
61-
push!(q.args, Expr(:(=), Symbol(newparentname, u), Symbol(parentname, u)))
62-
end
63-
reduce_expr!(q, newparentname, Instruction(reduction_to_single_vector(instruction(newparentop))), U)
64-
push!(q.args, Expr(:(=), newparentname, Symbol(newparentname, 0)))
58+
for u 0:U-1
59+
push!(q.args, Expr(:(=), Symbol(newparentname, u), Symbol(parentname, u)))
6560
end
61+
reduce_expr!(q, newparentname, Instruction(reduction_to_single_vector(instruction(newparentop))), U)
62+
push!(q.args, Expr(:(=), newparentname, Symbol(newparentname, 0)))
6663
end
6764
end
6865
end
@@ -96,8 +93,8 @@ function lower_compute!(
9693
for u 0:Uiter
9794
instrcall = Expr(instr) # Expr(:call, instr)
9895
varsym = if tiledouterreduction > 0 # then suffix !== nothing
99-
# modsuffix = ((u + suffix*U) & 3)
100-
modsuffix = suffix # (suffix & 3)
96+
modsuffix = ((u + suffix*U) & 3)
97+
# modsuffix = suffix # (suffix & 3)
10198
Symbol(mvar, modsuffix)
10299
elseif unrollsym
103100
Symbol(mvar, u)

src/lower_store.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ variable_name(op::Operation, ::Nothing) = mangledvar(op)
4343
variable_name(op::Operation, suffix) = Symbol(mangledvar(op), suffix, :_)
4444
# variable_name(op::Operation, suffix, u::Int) = (n = variable_name(op, suffix); u < 0 ? n : Symbol(n, u))
4545
function reduce_range!(q::Expr, toreduct::Symbol, instr::Instruction, Uh::Int, Uh2::Int)
46-
for u 0:Uh-1
47-
tru = Symbol(toreduct, u)
48-
push!(q.args, Expr(:(=), tru, Expr(instr, tru, Symbol(toreduct, u + Uh))))
46+
for u Uh:Uh2-1
47+
tru = Symbol(toreduct, u - Uh)
48+
push!(q.args, Expr(:(=), tru, Expr(instr, tru, Symbol(toreduct, u))))
4949
end
5050
for u 2Uh:Uh2-1
5151
tru = Symbol(toreduct, u - 2Uh)

src/lowering.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
166166
loopisstatic = isstaticloop(loop) & (!nisvectorized)
167167

168168
remmask = inclmask | nisvectorized
169-
Ureduct = (n == num_loops(ls)) ? calc_Ureduct(ls, us) : -1
169+
Ureduct = (n == num_loops(ls) && (u₂ == -1)) ? calc_Ureduct(ls, us) : -1
170170
sl = startloop(loop, nisvectorized, ls.W, loopsym)
171171
tc = terminatecondition(loop, us, n, ls.W, loopsym, inclmask, UF)
172172
body = lower_block(ls, us, n, inclmask, UF)
@@ -373,7 +373,7 @@ function calc_Ureduct(ls::LoopSet, us::UnrollSpecification)
373373
elseif num_loops(ls) == u₁loopnum
374374
min(u₁, 4)
375375
else
376-
u₂ == -1 ? u₁ : u₂
376+
u₂ == -1 ? u₁ : 4#u₂
377377
end
378378
end
379379
function lower(ls::LoopSet, us::UnrollSpecification)

0 commit comments

Comments
 (0)