Skip to content

Commit a0f3448

Browse files
committed
Minor updates; still WIP on adding lowering of unrolled loops.
1 parent aa61df8 commit a0f3448

File tree

2 files changed

+108
-22
lines changed

2 files changed

+108
-22
lines changed

src/costs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ const OPAQUE_INSTRUCTION = InstructionCost(50, 50.0, -1.0, VectorizationBase.REG
6161
# hand, should indicate how many registers we're keeping live for the sake of eventually storing.
6262
const COST = Dict{Symbol,InstructionCost}(
6363
:getindex => InstructionCost(3,0.5,-3.0,0),
64-
:setindex! => InstructionCost(3,1.0,-3.0,1),
64+
:setindex! => InstructionCost(3,1.0,-3.0,1),
65+
:zero => InstructionCost(1,0.5),
66+
:one => InstructionCost(3,0.5),
6567
:(+) => InstructionCost(4,0.5),
6668
:(-) => InstructionCost(4,0.5),
6769
:(*) => InstructionCost(4,0.5),

src/graphs.jl

Lines changed: 105 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,22 @@ isdense(::Type{<:DenseArray}) = true
2929
@enum NodeType begin
3030
memload
3131
memstore
32-
compute
32+
compute_new
33+
compute_update
34+
# accumulator
3335
end
3436

37+
# const ID = Threads.Atomic{UInt}(0)
3538

39+
"""
40+
if node_type == memstore || node_type == compute_new || node_type == compute_store
41+
symbolic metadata contains info on direct dependencies / placement within loop.
42+
43+
44+
"""
3645
struct Operation
37-
identifier::Symbol
46+
identifier::UInt
47+
variable::Symbol
3848
elementbytes::Int
3949
instruction::Symbol
4050
node_type::NodeType
@@ -45,9 +55,16 @@ struct Operation
4555
children::Vector{Operation}
4656
numerical_metadata::Vector{Int}
4757
symbolic_metadata::Vector{Symbol}
48-
function Operation(elementbytes, instruction, node_type, identifier = gensym())
58+
function Operation(
59+
elementbytes,
60+
instruction,
61+
node_type,
62+
identifier,
63+
variable = gensym()
64+
)
65+
# identifier = Threads.atomic_add!(ID, one(UInt))
4966
new(
50-
identifier, elementbytes, instruction, node_type,
67+
identifier, variable, elementbytes, instruction, node_type,
5168
Set{Symbol}(), Operation[], Operation[], Int[], Symbol[]
5269
)
5370
end
@@ -65,6 +82,7 @@ parents(op::Operation) = op.parents
6582
children(op::Operation) = op.children
6683
loopdependencies(op::Operation) = op.dependencies
6784
identifier(op::Operation) = op.identifier
85+
name(op::Operation) = op.variable
6886
instruction(op::Operation) = op.instruction
6987

7088
function stride(op::Operation, sym::Symbol)
@@ -98,10 +116,11 @@ struct LoopSet
98116
loadops::Vector{Operation} # Split them to make it easier to iterate over just a subset
99117
computeops::Vector{Operation}
100118
storeops::Vector{Operation}
101-
119+
reductions::Set{Operation}
102120
end
103121
num_loops(ls::LoopSet) = length(ls.loops)
104122
isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
123+
itersyms(ls::LoopSet) = keys(ls.loops)
105124
function looprange(ls::LoopSet, s::Symbol)
106125
loop = ls.loops[s]
107126
Expr(:(:), 0, loop.hintexact ? loop.rangehint - 1 : Expr(:call, :(-), loop.rangesym, 1))
@@ -155,7 +174,8 @@ end
155174
function evaluate_cost_unroll(
156175
ls::LoopSet, order::Vector{Symbol}, max_cost = typemax(Float64), unrolled::Symbol = first(order)
157176
)
158-
included_vars = Set{Symbol}()
177+
# included_vars = Set{UInt}()
178+
included_vars = fill(false, length(operations(ls)))
159179
nested_loop_syms = Set{Symbol}()
160180
total_cost = 0.0
161181
iter = 1.0
@@ -174,10 +194,10 @@ function evaluate_cost_unroll(
174194
for op operations(ls)
175195
# won't define if already defined...
176196
id = identifier(op)
177-
id included_vars && continue
197+
included_vars[id] && continue
178198
# it must also be a subset of defined symbols
179199
loopdependencies(op) nested_loop_syms || continue
180-
push!(included_vars, id)
200+
included_vars[id] = true
181201

182202
total_cost += iter * first(cost(op, unrolled, Wshift, size_T))
183203
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
@@ -188,12 +208,12 @@ end
188208

189209
# only covers unrolled ops; everything else considered lifted?
190210
function depchain_cost!(
191-
skip::Set{Symbol}, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int, sl::Int = 0, rt::Float64 = 0.0
211+
skip::Vector{Bool}, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int, sl::Int = 0, rt::Float64 = 0.0
192212
)
193-
push!(skip, sym(op))
213+
skip[identifier(op)] = true
194214
# depth first search
195215
for opp parents(op)
196-
opp skip && continue
216+
skip[identifier(opp)] && continue
197217
sl, rt = depchain_cost!(skip, opp, unrolled, Wshift, size_T, sl, rt)
198218
end
199219
# Basically assuming memory and compute don't conflict, but everything else does
@@ -220,7 +240,7 @@ function determine_unroll_factor(
220240
# We also make sure register pressure is not too high.
221241
latency = 0
222242
recip_throughput = 0.0
223-
visited_nodes = Set{Symbol}()
243+
visited_nodes = fill(false, length(operations(ls)))
224244
for op operations(ls)
225245
if isreduction(op) && dependson(op, unrolled)
226246
sl, rt = depchain_cost!(visited_nodes, instruction(op), unrolled, Wshift, size_T)
@@ -283,7 +303,7 @@ function evaluate_cost_tile(
283303
@assert N 2 "Cannot tile merely $N loops!"
284304
tiled = order[1]
285305
unrolled = order[2]
286-
included_vars = Set{Symbol}()
306+
included_vars = fill(false, length(operations(ls)))
287307
nested_loop_syms = Set{Symbol}()
288308
iter = 1.0
289309
# Need to check if fusion is possible
@@ -308,12 +328,13 @@ function evaluate_cost_tile(
308328
end
309329
iter *= liter
310330
# check which vars we can define at this level of loop nest
311-
for op operations(ls)
331+
for (id, op) enumerate(operations(ls))
332+
@assert id == identifier(op) # testing, for now
312333
# won't define if already defined...
313-
sym(op) included_vars && continue
334+
included_vars[id] && continue
314335
# it must also be a subset of defined symbols
315336
loopdependencies(op) nested_loop_syms || continue
316-
push!(included_vars, sym(op))
337+
included_vars[id] = true
317338
rt, lat, rp = cost(op, unrolled, Wshift, size_T)
318339
rt *= iter
319340
isunrolled = unrolled loopdependencies(op)
@@ -428,15 +449,78 @@ function choose_order(ls::LoopSet)
428449
end
429450
end
430451

452+
function depends_on_assigned(op::Operation, assigned::Vector{Bool})
453+
for p parents(op)
454+
assigned[identifier(op)] && return true
455+
depends_on_assigned(p, assigned) && return true
456+
end
457+
false
458+
end
459+
431460
# construction requires ops inserted into operations vectors in dependency order.
432461
function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
462+
if isstaticloop(ls, first(order))
463+
lower_unroll_static(ls, order, U)
464+
else
465+
lower_unroll_dynamic(ls, order, U)
466+
end
467+
end
468+
function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
469+
# this function create the inner block
470+
args = Any[]
471+
nloops = length(order)
472+
unrolled = first(order)
473+
# included_syms = Set( (unrolled,) )
474+
included_vars = fill(false, length(operations(ls)))
475+
# to go inside out, we just have to include all those not-yet included depending on the current sym
476+
477+
n = 0
478+
loopsym = last(order)
479+
for (id,op) enumerate(operations(ls))
480+
# We add an op the first time all loop dependencies are met
481+
# when working through loops backwords, that equates to the first time we encounter a loop dependency
482+
loopsym dependencies(op) || continue
483+
included_vars[id] = true
484+
485+
486+
end
487+
for n 1:nloops - 2
488+
loopsym = order[nloops - n]
489+
blockq = Expr(:block, )
490+
loopq = Expr(:for, Expr(:(=), itersym, looprange), blockq)
491+
for (id,op) enumerate(operations(ls))
492+
included_vars[id] && continue
493+
# We add an op the first time all loop dependencies are met
494+
# when working through loops backwords, that equates to the first time we encounter a loop dependency
495+
loopsym dependencies(op) || continue
496+
included_vars[id] = true
497+
498+
after_loop = depends_on_assigned(op, included_vars)
499+
500+
501+
end
502+
end
503+
end
504+
function lower_unroll_static(ls::LoopSet, order::Vector{Symbol}, U::Int)
505+
506+
end
507+
function lower_unroll_dynamic(ls::LoopSet, order::Vector{Symbol}, U::Int)
433508
nested_loop_syms = Set{Symbol}()
434-
included_vars = Set{Symbol}()
435-
q = quote end
509+
# included_vars = Set{UInt}()
510+
included_vars = fill(false, length(operations(ls)))
511+
q = quote end #Expr(:block,)
512+
# rely on compiler to simplify integer indices
513+
for s itersyms(ls)
514+
push!(q.args, Expr(:(=), s, 0))
515+
end
436516
lastqargs = q.args
437517
postloop_reduction = false
438-
for itersym order
439-
# Add to set of defined symbles
518+
num_loops = length(order)
519+
unrolled = first(order)
520+
521+
for n 2:num_loops
522+
itersym = order[n]
523+
# Add to set of defined symbols
440524
push!(nested_loop_syms, itersym)
441525
# check which vars we can define at this level of loop nest
442526
if itersym === first(order)
@@ -445,7 +529,7 @@ function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
445529
loopq = looprange(ls::LoopSet, s::Symbol)
446530
end
447531
blockq = Expr(:block, )
448-
loopq = Expr(:for, Expr(:(=), itersym, looprange), )
532+
loopq = Expr(:for, Expr(:(=), itersym, looprange), blockq)
449533
for op operations(ls)
450534
# won't define if already defined...
451535
id = identifier(op)

0 commit comments

Comments
 (0)