@@ -29,12 +29,22 @@ isdense(::Type{<:DenseArray}) = true
29
29
# Kinds of node an `Operation` can be. Members are numbered from 0 in the order listed:
# `memload`, `memstore`, `compute_new`, `compute_update`.
# (An `accumulator` member is sketched in the original as a commented-out placeholder.)
@enum NodeType memload memstore compute_new compute_update
34
36
37
+ # const ID = Threads.Atomic{UInt}(0)
35
38
39
+ """
40
+ if node_type == memstore || node_type == compute_new || node_type == compute_update
41
+ symbolic metadata contains info on direct dependencies / placement within loop.
42
+
43
+
44
+ """
36
45
struct Operation
37
- identifier:: Symbol
46
+ identifier:: UInt
47
+ variable:: Symbol
38
48
elementbytes:: Int
39
49
instruction:: Symbol
40
50
node_type:: NodeType
@@ -45,9 +55,16 @@ struct Operation
45
55
children:: Vector{Operation}
46
56
numerical_metadata:: Vector{Int}
47
57
symbolic_metadata:: Vector{Symbol}
48
- function Operation (elementbytes, instruction, node_type, identifier = gensym ())
58
+ function Operation (
59
+ elementbytes,
60
+ instruction,
61
+ node_type,
62
+ identifier,
63
+ variable = gensym ()
64
+ )
65
+ # identifier = Threads.atomic_add!(ID, one(UInt))
49
66
new (
50
- identifier, elementbytes, instruction, node_type,
67
+ identifier, variable, elementbytes, instruction, node_type,
51
68
Set {Symbol} (), Operation[], Operation[], Int[], Symbol[]
52
69
)
53
70
end
@@ -65,6 +82,7 @@ parents(op::Operation) = op.parents
65
82
# Field accessors for `Operation`. Each one returns the named field unchanged.

"Return the `children` field of `op`."
function children(op::Operation)
    return op.children
end

"Return the `dependencies` field of `op`."
function loopdependencies(op::Operation)
    return op.dependencies
end

"Return the `identifier` field of `op`."
function identifier(op::Operation)
    return op.identifier
end

"Return the `variable` field of `op`."
function name(op::Operation)
    return op.variable
end

"Return the `instruction` field of `op`."
function instruction(op::Operation)
    return op.instruction
end
69
87
70
88
function stride (op:: Operation , sym:: Symbol )
@@ -98,10 +116,11 @@ struct LoopSet
98
116
loadops:: Vector{Operation} # Split them to make it easier to iterate over just a subset
99
117
computeops:: Vector{Operation}
100
118
storeops:: Vector{Operation}
101
-
119
+ reductions :: Set{Operation}
102
120
end
103
121
"Return the number of entries in `ls.loops`."
function num_loops(ls::LoopSet)
    return length(ls.loops)
end

"Return the `hintexact` field of the loop stored under `s` in `ls.loops`."
function isstaticloop(ls::LoopSet, s::Symbol)
    loop = ls.loops[s]
    return loop.hintexact
end

"Return the keys of `ls.loops`."
function itersyms(ls::LoopSet)
    return keys(ls.loops)
end
105
124
function looprange (ls:: LoopSet , s:: Symbol )
106
125
loop = ls. loops[s]
107
126
Expr (:(:), 0 , loop. hintexact ? loop. rangehint - 1 : Expr (:call , :(- ), loop. rangesym, 1 ))
155
174
function evaluate_cost_unroll (
156
175
ls:: LoopSet , order:: Vector{Symbol} , max_cost = typemax (Float64), unrolled:: Symbol = first (order)
157
176
)
158
- included_vars = Set {Symbol} ()
177
+ # included_vars = Set{UInt}()
178
+ included_vars = fill (false , length (operations (ls)))
159
179
nested_loop_syms = Set {Symbol} ()
160
180
total_cost = 0.0
161
181
iter = 1.0
@@ -174,10 +194,10 @@ function evaluate_cost_unroll(
174
194
for op ∈ operations (ls)
175
195
# won't define if already defined...
176
196
id = identifier (op)
177
- id ∈ included_vars && continue
197
+ included_vars[id] && continue
178
198
# it must also be a subset of defined symbols
179
199
loopdependencies (op) ⊆ nested_loop_syms || continue
180
- push! ( included_vars, id)
200
+ included_vars[id] = true
181
201
182
202
total_cost += iter * first (cost (op, unrolled, Wshift, size_T))
183
203
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
@@ -188,12 +208,12 @@ end
188
208
189
209
# only covers unrolled ops; everything else considered lifted?
190
210
function depchain_cost! (
191
- skip:: Set{Symbol } , op:: Operation , unrolled:: Symbol , Wshift:: Int , size_T:: Int , sl:: Int = 0 , rt:: Float64 = 0.0
211
+ skip:: Vector{Bool } , op:: Operation , unrolled:: Symbol , Wshift:: Int , size_T:: Int , sl:: Int = 0 , rt:: Float64 = 0.0
192
212
)
193
- push! ( skip, sym (op))
213
+ skip[ identifier (op)] = true
194
214
# depth first search
195
215
for opp ∈ parents (op)
196
- opp ∈ skip && continue
216
+ skip[ identifier (opp)] && continue
197
217
sl, rt = depchain_cost! (skip, opp, unrolled, Wshift, size_T, sl, rt)
198
218
end
199
219
# Basically assuming memory and compute don't conflict, but everything else does
@@ -220,7 +240,7 @@ function determine_unroll_factor(
220
240
# We also make sure register pressure is not too high.
221
241
latency = 0
222
242
recip_throughput = 0.0
223
- visited_nodes = Set {Symbol} ( )
243
+ visited_nodes = fill ( false , length ( operations (ls)) )
224
244
for op ∈ operations (ls)
225
245
if isreduction (op) && dependson (op, unrolled)
226
246
sl, rt = depchain_cost! (visited_nodes, instruction (op), unrolled, Wshift, size_T)
@@ -283,7 +303,7 @@ function evaluate_cost_tile(
283
303
@assert N ≥ 2 " Cannot tile merely $N loops!"
284
304
tiled = order[1 ]
285
305
unrolled = order[2 ]
286
- included_vars = Set {Symbol} ( )
306
+ included_vars = fill ( false , length ( operations (ls)) )
287
307
nested_loop_syms = Set {Symbol} ()
288
308
iter = 1.0
289
309
# Need to check if fusion is possible
@@ -308,12 +328,13 @@ function evaluate_cost_tile(
308
328
end
309
329
iter *= liter
310
330
# check which vars we can define at this level of loop nest
311
- for op ∈ operations (ls)
331
+ for (id, op) ∈ enumerate (operations (ls))
332
+ @assert id == identifier (op) # testing, for now
312
333
# won't define if already defined...
313
- sym (op) ∈ included_vars && continue
334
+ included_vars[id] && continue
314
335
# it must also be a subset of defined symbols
315
336
loopdependencies (op) ⊆ nested_loop_syms || continue
316
- push! ( included_vars, sym (op))
337
+ included_vars[id] = true
317
338
rt, lat, rp = cost (op, unrolled, Wshift, size_T)
318
339
rt *= iter
319
340
isunrolled = unrolled ∈ loopdependencies (op)
@@ -428,15 +449,78 @@ function choose_order(ls::LoopSet)
428
449
end
429
450
end
430
451
452
"""
    depends_on_assigned(op::Operation, assigned::Vector{Bool})

Return `true` if any ancestor of `op` (a parent, or a parent of a parent, and so on)
is flagged in `assigned`, and `false` otherwise.

`assigned` is indexed by `identifier(...)` of an operation. `op`'s own flag is not
consulted: only the operations reachable through `parents` are checked.
"""
function depends_on_assigned(op::Operation, assigned::Vector{Bool})
    for p ∈ parents(op)
        # Test the parent's flag, not `op`'s. Callers may already have marked `op`
        # itself, which would otherwise make every op with a parent report `true`.
        assigned[identifier(p)] && return true
        # Depth-first search through the remaining ancestors.
        depends_on_assigned(p, assigned) && return true
    end
    return false
end
459
+
431
460
# construction requires ops inserted into operations vectors in dependency order.
432
461
"""
    lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)

Forward to `lower_unroll_static` when `isstaticloop(ls, first(order))` is true and to
`lower_unroll_dynamic` otherwise, passing `ls`, `order` and `U` through unchanged.
"""
function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
    lower = isstaticloop(ls, first(order)) ? lower_unroll_static : lower_unroll_dynamic
    return lower(ls, order, U)
end
468
"""
    lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)

Work-in-progress builder for the inner block of an unrolled loop nest.

It walks `order` from the last loop towards the front and flags, in a `Vector{Bool}`
indexed by position in `operations(ls)`, each operation the first time one of its
loop dependencies is encountered. The `for` expressions it builds are not yet
collected or returned; the function currently returns `nothing`.
"""
function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
    # TODO: nothing is pushed to `args` yet, and `unrolled` / `U` are not used yet.
    args = Any[]
    nloops = length(order)
    unrolled = first(order)
    # included_vars[id] is true once operation `id` has been placed.
    included_vars = fill(false, length(operations(ls)))
    # To go inside out, we just have to include all those not-yet-included ops
    # depending on the current sym.
    loopsym = last(order)
    for (id, op) ∈ enumerate(operations(ls))
        # We add an op the first time all loop dependencies are met; when working
        # through loops backwards, that equates to the first time we encounter a
        # loop dependency.
        loopsym ∈ loopdependencies(op) || continue
        included_vars[id] = true
    end
    for n ∈ 1:nloops-2
        loopsym = order[nloops - n]
        blockq = Expr(:block)
        # Iterate `loopsym` over the range expression built for that loop.
        loopq = Expr(:for, Expr(:(=), loopsym, looprange(ls, loopsym)), blockq)
        for (id, op) ∈ enumerate(operations(ls))
            included_vars[id] && continue
            # Same rule as above: place the op at the first loop it depends on.
            loopsym ∈ loopdependencies(op) || continue
            included_vars[id] = true
            # TODO: `after_loop` is computed but not acted upon yet.
            after_loop = depends_on_assigned(op, included_vars)
        end
    end
    return nothing
end
504
"""
    lower_unroll_static(ls::LoopSet, order::Vector{Symbol}, U::Int)

Placeholder with no implementation yet; returns `nothing`.
"""
function lower_unroll_static(ls::LoopSet, order::Vector{Symbol}, U::Int)
    return nothing
end
507
+ function lower_unroll_dynamic (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
433
508
nested_loop_syms = Set {Symbol} ()
434
- included_vars = Set {Symbol} ()
435
- q = quote end
509
+ # included_vars = Set{UInt}()
510
+ included_vars = fill (false , length (operations (ls)))
511
+ q = quote end # Expr(:block,)
512
+ # rely on compiler to simplify integer indices
513
+ for s ∈ itersyms (ls)
514
+ push! (q. args, Expr (:(= ), s, 0 ))
515
+ end
436
516
lastqargs = q. args
437
517
postloop_reduction = false
438
- for itersym ∈ order
439
- # Add to set of defined symbles
518
+ num_loops = length (order)
519
+ unrolled = first (order)
520
+
521
+ for n ∈ 2 : num_loops
522
+ itersym = order[n]
523
+ # Add to set of defined symbols
440
524
push! (nested_loop_syms, itersym)
441
525
# check which vars we can define at this level of loop nest
442
526
if itersym === first (order)
@@ -445,7 +529,7 @@ function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
445
529
loopq = looprange (ls:: LoopSet , s:: Symbol )
446
530
end
447
531
blockq = Expr (:block , )
448
- loopq = Expr (:for , Expr (:(= ), itersym, looprange), )
532
+ loopq = Expr (:for , Expr (:(= ), itersym, looprange), blockq )
449
533
for op ∈ operations (ls)
450
534
# won't define if already defined...
451
535
id = identifier (op)
0 commit comments