Skip to content

Commit 5587223

Browse files
committed
Move operations after inner loop when possible.
1 parent 6c2de2b commit 5587223

File tree

2 files changed

+63
-27
lines changed

2 files changed

+63
-27
lines changed

src/determinestrategy.jl

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ function solve_tilesize(
270270
solve_tilesize(cost_vec, reg_pressure, maxU, maxT)
271271
end
272272

273+
function set_for_each_parent!(adal::Vector{T}, op::Operation, val::T) where {T}
274+
@inbounds for opp ∈ parents(op)
275+
adal[identifier(opp)] = val
276+
end
277+
end
278+
273279
# Just tile outer two loops?
274280
# But optimal order within tile must still be determined
275281
# as well as size of the tiles.
@@ -280,7 +286,13 @@ function evaluate_cost_tile(
280286
@assert N ≥ 2 "Cannot tile merely $N loops!"
281287
tiled = order[1]
282288
unrolled = order[2]
283-
included_vars = fill(false, length(operations(ls)))
289+
ops = operations(ls)
290+
nops = length(ops)
291+
included_vars = fill(false, nops)
292+
unrolledtiled = fill(false, 2, nops)
293+
descendentsininnerloop = fill(false, nops)
294+
innerloop = last(order)
295+
iters = fill(-99.9, nops)
284296
nested_loop_syms = Symbol[]# Set{Symbol}()
285297
iter = 1.0
286298
# Need to check if fusion is possible
@@ -306,7 +318,7 @@ function evaluate_cost_tile(
306318
iter *= Float64(length(ls, itersym))
307319
end
308320
# check which vars we can define at this level of loop nest
309-
for (id, op) ∈ enumerate(operations(ls))
321+
for (id, op) ∈ enumerate(ops)
310322
# isconstant(op) && continue
311323
# @assert id == identifier(op)+1 # testing, for now
312324
# won't define if already defined...
@@ -318,27 +330,37 @@ function evaluate_cost_tile(
318330
rd = reduceddependencies(op)
319331
hasintersection(rd, nested_loop_syms[1:end-length(rd)]) && return 0,0,Inf
320332
included_vars[id] = true
321-
rt, lat, rp = cost(op, vectorized, Wshift, size_T)
333+
unrolledtiled[1,id] = unrolled ∈ loopdependencies(op)
334+
unrolledtiled[2,id] = tiled ∈ loopdependencies(op)
335+
iters[id] = iter
336+
innerloop ∈ loopdependencies(op) && set_for_each_parent!(descendentsininnerloop, op, true)
337+
end
338+
end
339+
for (id, op) ∈ enumerate(ops)
340+
iters[id] == -99.9 && continue
341+
descendentsininnerloop[id] || continue
342+
isunrolled = unrolledtiled[1,id]
343+
istiled = unrolledtiled[2,id]
344+
rt, lat, rp = cost(op, vectorized, Wshift, size_T)
322345
# @show instruction(op), rt, lat, rp, iter
323-
rt *= iter
324-
isunrolled = unrolled ∈ loopdependencies(op)
325-
istiled = tiled ∈ loopdependencies(op)
346+
rt *= iters[id]
326347
# @show isunrolled, istiled
327-
if isunrolled && istiled # no cost decrease; cost must be repeated
328-
cost_vec[1] += rt
329-
reg_pressure[1] += rp
330-
elseif isunrolled # cost decreased by tiling
331-
cost_vec[2] += rt
332-
reg_pressure[2] += rp
333-
elseif istiled # cost decreased by unrolling
334-
cost_vec[3] += rt
335-
reg_pressure[3] += rp
336-
else# neither unrolled or tiled
337-
cost_vec[4] += rt
338-
reg_pressure[4] += rp
339-
end
348+
if isunrolled && istiled # no cost decrease; cost must be repeated
349+
cost_vec[1] += rt
350+
reg_pressure[1] += rp
351+
elseif isunrolled # cost decreased by tiling
352+
cost_vec[2] += rt
353+
reg_pressure[2] += rp
354+
elseif istiled # cost decreased by unrolling
355+
cost_vec[3] += rt
356+
reg_pressure[3] += rp
357+
else# neither unrolled or tiled
358+
cost_vec[4] += rt
359+
reg_pressure[4] += rp
340360
end
341361
end
362+
# @show order, vectorized cost_vec reg_pressure
363+
# @show solve_tilesize(ls, unrolled, tiled, cost_vec, reg_pressure)
342364
solve_tilesize(ls, unrolled, tiled, cost_vec, reg_pressure)
343365
end
344366

src/graphs.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -603,15 +603,17 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
603603
end
604604
end
605605

606-
function place_after_loop(op::Operation)
607-
if isload(op) || length(reduceddependencies(op)) == 0
606+
function place_after_loop!(adal::Vector{Bool}, op::Operation)
607+
pal = if isload(op) || length(reduceddependencies(op)) == 0
608608
1
609609
elseif length(reduceddependencies(op)) > 1
610610
2
611611
else
612612
rd = first(reduceddependencies(op))
613613
any(d -> d === rd, loopdependencies(op)) ? 1 : 2
614614
end
615+
pal == 1 && set_for_each_parent!(adal, op, false)
616+
pal
615617
end
616618

617619
function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
@@ -620,7 +622,6 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
620622
# @show 1, ro, order
621623
# copyto!(ro, order)
622624
# @show 2, ro, order
623-
empty!(lo)
624625
nloops = length(order)
625626
if loopistiled
626627
tiled = order[1]
@@ -629,23 +630,36 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
629630
tiled = Symbol("##UNDEFINED##")
630631
unrolled = first(order)
631632
end
632-
included_vars = fill(false, length(operations(ls)))
633+
ops = operations(ls)
634+
nops = length(ops)
635+
included_vars = fill(false, nops)
636+
all_descendents_after_loop = fill(true, nops)
637+
positions = fill((-1,-1,-1,-1,-1), nops)#Vector{NTuple{5,Int}}(undef, nops)
633638
# to go inside out, we just have to include all those not-yet included depending on the current sym
634639
for _n ∈ 1:nloops
635640
n = 1 + nloops - _n
636641
ro[_n] = loopsym = order[n]
637642
#loopsym = order[n]
638-
for (id,op) ∈ enumerate(operations(ls))
643+
for (id,op) ∈ enumerate(ops)
639644
included_vars[id] && continue
640645
loopsym ∈ loopdependencies(op) || continue
641646
included_vars[id] = true
642647
isunrolled = (unrolled ∈ loopdependencies(op)) + 1
643648
istiled = (loopistiled ? (tiled ∈ loopdependencies(op)) : false) + 1
644649
optype = Int(op.node_type) + 1
645-
after_loop = place_after_loop(op)
646-
push!(lo[optype,isunrolled,istiled,after_loop,_n], op)
650+
after_loop = place_after_loop!(all_descendents_after_loop, op)
651+
positions[id] = (optype,isunrolled,istiled,after_loop,_n)
647652
end
648-
end
653+
end
654+
empty!(lo)
655+
for id ∈ 1:nops
656+
optype,isunrolled,istiled,after_loop,_n = positions[id]
657+
optype == -1 && continue#@show ops[id]
658+
if all_descendents_after_loop[id]
659+
after_loop = 2
660+
end
661+
push!(lo[optype,isunrolled,istiled,after_loop,_n], ops[id])
662+
end
649663
# 3, ro, order
650664
end
651665

0 commit comments

Comments
 (0)