Skip to content

Commit 38e32b8

Browse files
committed
Minor progress on lowering.
1 parent dc0252b commit 38e32b8

File tree

1 file changed

+53
-19
lines changed

1 file changed

+53
-19
lines changed

src/graphs.jl

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ struct Operation
7070
node_type,
7171
variable = gensym()
7272
)
73-
# identifier = Threads.atomic_add!(ID, one(UInt))
7473
new(
7574
identifier, variable, elementbytes, instruction, node_type,
7675
Set{Symbol}(), Operation[], Operation[], Int[], Symbol[]#, Dict{Symbol,Union{Symbol,Int}}()
@@ -165,7 +164,7 @@ isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
165164
itersyms(ls::LoopSet) = keys(ls.loops)
166165
function looprange(ls::LoopSet, s::Symbol)
167166
loop = ls.loops[s]
168-
Expr(:(:), 0, loop.hintexact ? loop.rangehint - 1 : Expr(:call, :(-), loop.rangesym, 1))
167+
Expr(:call, :<, s, loop.hintexact ? loop.rangehint : loop.rangesym)
169168
end
170169
function Base.length(ls::LoopSet, is::Symbol)
171170
ls.loops[is].rangehint
@@ -499,7 +498,8 @@ function depends_on_assigned(op::Operation, assigned::Vector{Bool})
499498
end
500499
false
501500
end
502-
function replace_ind_in_offset!(offset::Vector, op::Operation, ind::Int, dynamic::Bool, t)
501+
# ind gets increased across tiles / unroll, so we need steps.
502+
function replace_ind_in_offset!(offset::Vector, op::Operation, ind::Int, t)
503503
t == 0 && return nothing
504504
var = op.variable
505505
siter = op.symbolic_metadata[ind]
@@ -538,7 +538,7 @@ function lower_load!(
538538
end
539539
else # tiling
540540
for t ∈ 0:T-1
541-
replace_ind_inoffset!(memoff, op, tind, t)
541+
replace_ind_in_offset!(memoff, op, tind, t)
542542
for u ∈ 0:U-1
543543
memoff2 = copy(memoff)
544544
u > 0 && push!(memoff2, W*u)
@@ -553,7 +553,7 @@ function lower_load!(
553553
ustrides = Expr(:tuple, (ustride > 1 ? [Core.VecElement{Int}(ustride*w) for w ∈ 0:W-1] : [:(Core.VecElement{Int}($(op.symbolic_metadata[upos])*$w)) for w ∈ 0:W-1])...)
554554
if T != -1 # gather tile
555555
for t ∈ 0:T-1
556-
replace_ind_inoffset!(memoff, op, tind, t)
556+
replace_ind_in_offset!(memoff, op, tind, t)
557557
for u ∈ 0:U-1
558558
memoff2 = copy(memoff)
559559
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
@@ -580,7 +580,7 @@ function lower_load!(
580580
# load per T.
581581
# memoff2 = copy(memoff)
582582
for t ∈ 0:T-1
583-
replace_ind_inoffset!(memoff, op, tind, t)
583+
replace_ind_in_offset!(memoff, op, tind, t)
584584
instrcall = Expr(:call, :load, ptr, copy(memoff))
585585
# mask === nothing || push!(instrcall.args, mask)
586586
push!(q.args, Expr(:(=), Symbol(var,:_,t), instrcall))
@@ -613,7 +613,7 @@ function lower_store!(
613613
end
614614
else # tiling
615615
for t ∈ 0:T-1
616-
replace_ind_inoffset!(memoff, op, tind, t)
616+
replace_ind_in_offset!(memoff, op, tind, t)
617617
for u ∈ 0:U-1
618618
memoff2 = copy(memoff)
619619
u > 0 && push!(memoff2, W*u)
@@ -628,7 +628,7 @@ function lower_store!(
628628
ustrides = Expr(:tuple, (ustride > 1 ? [Core.VecElement{Int}(ustride*w) for w ∈ 0:W-1] : [:(Core.VecElement{Int}($(op.symbolic_metadata[upos])*$w)) for w ∈ 0:W-1])...)
629629
if T != -1 # gather tile
630630
for t ∈ 0:T-1
631-
replace_ind_inoffset!(memoff, op, tind, t)
631+
replace_ind_in_offset!(memoff, op, tind, t)
632632
for u ∈ 0:U-1
633633
memoff2 = copy(memoff)
634634
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
@@ -658,11 +658,13 @@ function lower_store!(
658658
# store per T.
659659
# memoff2 = copy(memoff)
660660
for t ∈ 0:T-1
661-
replace_ind_inoffset!(memoff, op, tind, t)
662-
push!(q.args, Expr(:call, :store!, ptr, Symbol(var,:_,t), copy(memoff)))
661+
replace_ind_in_offset!(memoff, op, tind, t)
662+
storevar = Expr(:call, reduct, Symbol(var,:_,t))
663+
push!(q.args, Expr(:call, :store!, ptr, storevar, copy(memoff)))
663664
end
664665
else # no unroll
665-
push!(q.args, Expr(:call, :store!, var, ptr, memoff))
666+
storevar = Expr(:call, reduct, var)
667+
push!(q.args, Expr(:call, :store!, ptr, storevar, memoff))
666668
end
667669
end
668670
end
@@ -673,6 +675,7 @@ function lower_compute!(
673675
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"), mask = nothing
674676
)
675677
opunrolled = unrolled ∈ loopdependencies(op)
678+
676679
optiled = tiled ∈ loopdependencies(op)
677680
var = op.variable
678681
instr = op.instruction
@@ -745,7 +748,8 @@ function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
745748
lower_unroll_dynamic(ls, order, U)
746749
end
747750
end
748-
function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
751+
function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int, peel::Int = 1)
752+
@assert peel ≥ 0
749753
# this function creates the inner block
750754
# args = Any[]
751755
nloops = length(order)
@@ -755,39 +759,68 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
755759
# to go inside out, we just have to include all those not-yet included depending on the current sym
756760
n = 0
757761
loopsym = last(order)
758-
blockq = Expr(:block, )
759-
loopq = Expr(:for, Expr(:(=), itersym, looprange(ls, loopsym)), blockq)
762+
blockq = Expr(:block, )#Expr(:(=), loopsym, 0))
763+
loopq = Expr(:while, looprange(ls, loopsym), blockq)
760764
for (id,op) ∈ enumerate(operations(ls))
761765
# We add an op the first time all loop dependencies are met
762766
# when working through loops backwards, that equates to the first time we encounter a loop dependency
763767
loopsym ∈ dependencies(op) || continue
764768
included_vars[id] = true
765769
lower!(blockq, op, unrolled, U)
766770
end
767-
for n ∈ 1:nloops - 2
771+
for n ∈ 1:nloops - 1 - peel
772+
blockq = Expr(:block, Expr(:(=), loopsym, 0)) # sets old loopsym to 0
768773
loopsym = order[nloops - n]
769-
blockq = Expr(:block, )
770774
postloop = Expr(:block, )
771775
for (id,op) ∈ enumerate(operations(ls))
772776
included_vars[id] && continue
773777
# We add an op the first time all loop dependencies are met
774778
# when working through loops backwards, that equates to the first time we encounter a loop dependency
775779
loopsym ∈ dependencies(op) || continue
776780
included_vars[id] = true
777-
781+
778782
after_loop = depends_on_assigned(op, included_vars)
779783
after_loop || lower!(blockq, op, unrolled, U)
780784
after_loop && lower!(postloop, op, unrolled, U)
781785
end
782786
push!(blockq.args, loopq_old); append!(blockq.args, postloop.args)
783-
loopq = Expr(:for, Expr(:(=), itersym, looprange), blockq)
787+
push!(blockq, Expr(:+=, loopsym, 1))
788+
loopq = Expr(:while, looprange(ls, loopsym), blockq)
784789
end
785-
loopq
790+
Expr(:block, Expr(:=, order[1 + peel], 0), loopq), included_vars
786791
end
787792
function lower_unroll_static(ls::LoopSet, order::Vector{Symbol}, U::Int)
788793

789794
end
790795
function lower_unroll_dynamic(ls::LoopSet, order::Vector{Symbol}, U::Int)
796+
797+
798+
unrolled = first(order)
799+
q = Expr(:block, )
800+
801+
# we repeatedly break into smaller chunks.
802+
while U > 0
803+
inner_block, included_vars = lower_unroll_inner_block(ls, order, U, 1)
804+
805+
end
806+
807+
Uispow2 = VectorizationBase.ispow2(U)
808+
looprange(ls, loopsym)
809+
810+
loop = ls.loops[s]
811+
Expr(:(:), 0, loop.hintexact ? loop.rangehint - 1 : Expr(:call, :(-), loop.rangesym, 1))
812+
813+
if U == 1 # no unrolling needed
814+
815+
elseif Uispow2 # we use shifts and bitwise &
816+
log2U = VectorizationBase.intlog2(U)
817+
818+
else
819+
820+
end
821+
822+
# now must repeat inner block
823+
791824
nested_loop_syms = Set{Symbol}()
792825
# included_vars = Set{UInt}()
793826
included_vars = fill(false, length(operations(ls)))
@@ -812,6 +845,7 @@ function lower_unroll_dynamic(ls::LoopSet, order::Vector{Symbol}, U::Int)
812845
loopq = looprange(ls::LoopSet, s::Symbol)
813846
end
814847
blockq = Expr(:block, )
848+
815849
loopq = Expr(:for, Expr(:(=), itersym, looprange), blockq)
816850
for op ∈ operations(ls)
817851
# won't define if already defined...

0 commit comments

Comments
 (0)