1
1
# function unitstride(op::Operation, sym::Symbol)
2
2
# (first(op.symbolic_metadata) === sym) && (first(op.numerical_metadata) == 1)
3
3
# end
4
- function mem_offset (op:: Operation , incr:: Int = 0 ):: Union{Symbol,Expr}
4
+ function mem_offset (op:: Operation , incr:: Int = 0 )
5
5
@assert accesses_memory (op) " Computing memory offset only makes sense for operations that access memory."
6
6
ret = Expr (:tuple , )
7
7
deps = op. dependencies
@@ -15,7 +15,7 @@ function mem_offset(op::Operation, incr::Int = 0)::Union{Symbol,Expr}
15
15
end
16
16
ret
17
17
end
18
- function mem_offset (op:: Operation , incr:: Int , unrolled:: Symbol ):: Union{Symbol,Expr}
18
+ function mem_offset (op:: Operation , incr:: Int , unrolled:: Symbol )
19
19
@assert accesses_memory (op) " Computing memory offset only makes sense for operations that access memory."
20
20
ret = Expr (:tuple , )
21
21
deps = op. dependencies
@@ -268,10 +268,23 @@ function lower_compute!(
268
268
var = op. variable
269
269
parents_op = parents (op)
270
270
nparents = length (parents_op)
271
+ if opunrolled
272
+ parentsunrolled = Vector {Bool} (undef, nparents)
273
+ for (p,opp) ∈ enumerate (parents_op)
274
+ # if op is an inner reduction, one of its parents will be the initialization of op
275
+ # They will share the same `variable` field. The initialization may not have
276
+ # unrolled in its loop dependencies, but (if opunrolled) op itself is, so we return true
277
+ parentsunrolled[p] = var === opp. variable ? true : (unrolled ∈ loopdependencies (opp))
278
+ end
279
+ else # maybe skip allocating this?
280
+ parentsunrolled = fill (false , nparents)
281
+ end
271
282
parentstiled = if suffix === nothing
272
283
optiled = false
284
+ tiledouterreduction = false
273
285
fill (false , nparents)
274
286
else
287
+ tiledouterreduction = identifier (op) ∈
275
288
var = Symbol (var, :_ , suffix)
276
289
optiled = true
277
290
[tiled ∈ loopdependencies (opp) for opp ∈ parents_op]
@@ -280,7 +293,7 @@ function lower_compute!(
280
293
# cache unroll and tiling check of parents
281
294
# not broadcasted, because we use frequent checks of individual bools
282
295
# making BitArrays inefficient.
283
- parentsunrolled = opunrolled ? [unrolled ∈ loopdependencies (opp) for opp ∈ parents_op] : fill ( false , nparents)
296
+ @show instr parentsunrolled
284
297
# parentsyms = [opp.variable for opp ∈ parents(op)]
285
298
Uiter = opunrolled ? U - 1 : 0
286
299
maskreduct = mask != = nothing && isreduction (op)# any(opp -> opp.variable === var, parents_op)
@@ -416,6 +429,7 @@ function lower_nest(
416
429
istiled = T != - 1
417
430
loopsym = order[n]
418
431
nloops = num_loops (ls)
432
+ outer_reduce = length (ls. outer_reductions) > 0
419
433
if istiled
420
434
if n == nloops
421
435
loopsym = tiledsym (loopsym)
@@ -433,15 +447,12 @@ function lower_nest(
433
447
loopincr = n == nloops ? U* W : 1
434
448
end
435
449
@show unrolled, order
436
- blockq = if n == 1
437
- Expr (:block , )
438
- else
439
- Expr (:block , Expr (:(= ), order[n- 1 ], loopstart))
440
- end
450
+ blockq = Expr (:block , )
451
+ n == 1 || push! (blockq. args, Expr (:(= ), order[n- 1 ], loopstart))
441
452
loopq = if exprtype === :block
442
453
blockq
443
454
else
444
- @assert exprtype === :while || exprtype === :if
455
+ @assert exprtype === :while || exprtype === :if " Expression type $exprtype not recognized. "
445
456
Expr (exprtype, looprange (ls, loopsym, loopincr), blockq)
446
457
end
447
458
for prepost ∈ 1 : 2
@@ -603,7 +614,7 @@ function lower_tiled(ls::LoopSet, U::Int, T::Int)
603
614
Texprtype = (static_tile && tiled_iter < 2 T) ? :block : :while
604
615
while Tt > 0
605
616
tiledloopbody = Expr (:block , Expr (:(= ), unrolled, 0 ))
606
- push! (q. args, Texprtype === :block ? tiledloopbody : Expr (Texprtype, looprange (ls, tiled, Tt), tiledloopbody))
617
+ push! (q. args, Texprtype === :block ? tiledloopbody : Expr (Texprtype, looprange (ls, tiledsym ( tiled) , Tt), tiledloopbody))
607
618
lower_unrolled! (tiledloopbody, ls, U, Tt, W, static_unroll, unrolled_iter, unrolled_itersym)
608
619
if static_tile
609
620
Tt = if Tt == T
0 commit comments