Skip to content

Commit 62c20f6

Browse files
committed
A few updates; committing before merging.
1 parent d21d018 commit 62c20f6

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

src/lowering.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# function unitstride(op::Operation, sym::Symbol)
22
# (first(op.symbolic_metadata) === sym) && (first(op.numerical_metadata) == 1)
33
# end
4-
function mem_offset(op::Operation, incr::Int = 0)::Union{Symbol,Expr}
4+
function mem_offset(op::Operation, incr::Int = 0)
55
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
66
ret = Expr(:tuple, )
77
deps = op.dependencies
@@ -15,7 +15,7 @@ function mem_offset(op::Operation, incr::Int = 0)::Union{Symbol,Expr}
1515
end
1616
ret
1717
end
18-
function mem_offset(op::Operation, incr::Int, unrolled::Symbol)::Union{Symbol,Expr}
18+
function mem_offset(op::Operation, incr::Int, unrolled::Symbol)
1919
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
2020
ret = Expr(:tuple, )
2121
deps = op.dependencies
@@ -268,10 +268,23 @@ function lower_compute!(
268268
var = op.variable
269269
parents_op = parents(op)
270270
nparents = length(parents_op)
271+
if opunrolled
272+
parentsunrolled = Vector{Bool}(undef, nparents)
273+
for (p,opp) ∈ enumerate(parents_op)
274+
# if op is an inner reduction, one of its parents will be the initialization of op
275+
# They will share the same `variable` field. The initialization may not have
276+
# unrolled in its loop dependencies, but (if opunrolled) op itself is, so we return true
277+
parentsunrolled[p] = var === opp.variable ? true : (unrolled ∈ loopdependencies(opp))
278+
end
279+
else # maybe skip allocating this?
280+
parentsunrolled = fill(false, nparents)
281+
end
271282
parentstiled = if suffix === nothing
272283
optiled = false
284+
tiledouterreduction = false
273285
fill(false, nparents)
274286
else
287+
tiledouterreduction = identifier(op)
275288
var = Symbol(var, :_, suffix)
276289
optiled = true
277290
[tiled ∈ loopdependencies(opp) for opp ∈ parents_op]
@@ -280,7 +293,7 @@ function lower_compute!(
280293
# cache unroll and tiling check of parents
281294
# not broadcasted, because we use frequent checks of individual bools
282295
# making BitArrays inefficient.
283-
parentsunrolled = opunrolled ? [unrolled ∈ loopdependencies(opp) for opp ∈ parents_op] : fill(false, nparents)
296+
@show instr parentsunrolled
284297
# parentsyms = [opp.variable for opp ∈ parents(op)]
285298
Uiter = opunrolled ? U - 1 : 0
286299
maskreduct = mask !== nothing && isreduction(op)#any(opp -> opp.variable === var, parents_op)
@@ -416,6 +429,7 @@ function lower_nest(
416429
istiled = T != -1
417430
loopsym = order[n]
418431
nloops = num_loops(ls)
432+
outer_reduce = length(ls.outer_reductions) > 0
419433
if istiled
420434
if n == nloops
421435
loopsym = tiledsym(loopsym)
@@ -433,15 +447,12 @@ function lower_nest(
433447
loopincr = n == nloops ? U*W : 1
434448
end
435449
@show unrolled, order
436-
blockq = if n == 1
437-
Expr(:block, )
438-
else
439-
Expr(:block, Expr(:(=), order[n-1], loopstart))
440-
end
450+
blockq = Expr(:block, )
451+
n == 1 || push!(blockq.args, Expr(:(=), order[n-1], loopstart))
441452
loopq = if exprtype === :block
442453
blockq
443454
else
444-
@assert exprtype === :while || exprtype === :if
455+
@assert exprtype === :while || exprtype === :if "Expression type $exprtype not recognized."
445456
Expr(exprtype, looprange(ls, loopsym, loopincr), blockq)
446457
end
447458
for prepost ∈ 1:2
@@ -603,7 +614,7 @@ function lower_tiled(ls::LoopSet, U::Int, T::Int)
603614
Texprtype = (static_tile && tiled_iter < 2T) ? :block : :while
604615
while Tt > 0
605616
tiledloopbody = Expr(:block, Expr(:(=), unrolled, 0))
606-
push!(q.args, Texprtype === :block ? tiledloopbody : Expr(Texprtype, looprange(ls, tiled, Tt), tiledloopbody))
617+
push!(q.args, Texprtype === :block ? tiledloopbody : Expr(Texprtype, looprange(ls, tiledsym(tiled), Tt), tiledloopbody))
607618
lower_unrolled!(tiledloopbody, ls, U, Tt, W, static_unroll, unrolled_iter, unrolled_itersym)
608619
if static_tile
609620
Tt = if Tt == T

src/operations.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,20 @@ identifier(op::Operation) = op.identifier + 1
7777
name(op::Operation) = op.variable
7878
instruction(op::Operation) = op.instruction
7979

80+
function isouterreduction(op::Operation)
81+
if isconstant(op)
82+
op.instruction === Symbol("##CONSTANT##")
83+
elseif iscompute(op)
84+
var = op.variable
85+
for opp ∈ parents(op)
86+
opp.variable === var && opp.instruction === Symbol("##CONSTANT##") && return true
87+
end
88+
false
89+
else
90+
false
91+
end
92+
end
93+
8094
# function hasintersection(s1::Set{T}, s2::Set{T}) where {T}
8195
# for x ∈ s1
8296
# x ∈ s2 && return true

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ dotq = :(for i ∈ eachindex(a)
6666
lsdot = LoopVectorization.LoopSet(dotq);
6767
@test LoopVectorization.choose_order(lsdot) == (Symbol[:i], 8, -1)
6868
LoopVectorization.lower(lsdot)
69+
lsdot.operations
6970

7071
vexpq = :(for i ∈ eachindex(a)
7172
b[i] = exp(a[i])
@@ -80,6 +81,7 @@ vexpsq = :(for i ∈ eachindex(a)
8081
lsvexps = LoopVectorization.LoopSet(vexpsq);
8182
@test LoopVectorization.choose_order(lsvexps) == (Symbol[:i], 1, -1)
8283
LoopVectorization.lower(lsvexps)
84+
lsvexps.operations
8385

8486
gemvq = :(for i ∈ eachindex(y)
8587
yᵢ = 0.0

0 commit comments

Comments
 (0)