Skip to content

Commit 0e4dad0

Browse files
committed
Minor update.
1 parent 5762657 commit 0e4dad0

File tree

1 file changed

+32
-46
lines changed

1 file changed

+32
-46
lines changed

src/graphs.jl

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -807,11 +807,32 @@ function lower!(
807807
)
808808
foreach(op -> lower!(q, op, W, unrolled, U, suffix, mask), ops)
809809
end
810+
function lower_load!(
811+
q::Expr, ops::AbstractVector{Operation}, W::Int, unrolled::Symbol, U::Int,
812+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
813+
)
814+
foreach(op -> lower_load!(q, op, W, unrolled, U, suffix, mask), ops)
815+
end
816+
function lower_compute!(
817+
q::Expr, ops::AbstractVector{Operation}, W::Int, unrolled::Symbol, U::Int,
818+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
819+
)
820+
foreach(op -> lower_compute!(q, op, W, unrolled, U, suffix, mask), ops)
821+
end
822+
function lower_store!(
823+
q::Expr, ops::AbstractVector{Operation}, W::Int, unrolled::Symbol, U::Int,
824+
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
825+
)
826+
foreach(op -> lower_store!(q, op, W, unrolled, U, suffix, mask), ops)
827+
end
810828
function lower!(
811-
q::Expr, op::Operation, W::Int, unrolled::Symbol, U::Int,
829+
q::Expr, ops::AbstractVector{<:AbstractVector{Operation}}, W::Int, unrolled::Symbol, U::Int,
812830
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
813831
)
814-
foreach(op -> lower!(q, op, W, unrolled, U, suffix, mask), ops)
832+
833+
foreach(op -> lower_load!(q, op, W, unrolled, U, suffix, mask), ops[1])
834+
foreach(op -> lower_compute!(q, op, W, unrolled, U, suffix, mask), ops[2])
835+
foreach(op -> lower_store!(q, op, W, unrolled, U, suffix, mask), ops[3])
815836
end
816837

817838

@@ -840,6 +861,10 @@ function lower_inner_block(ls::LoopSet, U::Int, T::Int, peel::Int = 1)
840861
tiled = Symbol("##UNDEFINED##")
841862
end
842863
local loopq_old::Expr
864+
# Probably delete peel
865+
# Have this function generate entire body
866+
# And pass custom range, or range strategy
867+
# to be used for unrolled and optionally tiled dims.
843868
for n 1:nloops - peel
844869
loopsym = order[n]
845870
blockq = if n == 1
@@ -850,14 +875,14 @@ function lower_inner_block(ls::LoopSet, U::Int, T::Int, peel::Int = 1)
850875
loopq = Expr(:while, looprange(ls, loopsym), blockq)
851876
for prepost 1:2
852877
# !U && !T
853-
lower_scalar!(blockq, @view(ops[:,1,1,prepost,n]), W, unrolled, U, nothing, mask)
878+
lower!(blockq, @view(ops[:,1,1,prepost,n]), W, unrolled, U, nothing, mask)
854879
for u 0:U-1 # U && !T
855-
lower_unrolled!(blockq, @view(ops[:,2,1,prepost,n]), W, unrolled, U, nothing, mask)
880+
lower!(blockq, @view(ops[:,2,1,prepost,n]), W, unrolled, U, nothing, mask)
856881
end
857882
for t 0:Titer # !U && T
858-
lower_scalar!(blockq, @view(ops[:,1,2,prepost,n]), W, unrolled, U, t, mask)
883+
lower!(blockq, @view(ops[:,1,2,prepost,n]), W, unrolled, U, t, mask)
859884
for u 0:U-1 # U && T
860-
lower_unrolled!(blockq, @view(ops[:,2,2,prepost,n]), W, unrolled, U, t, mask)
885+
lower!(blockq, @view(ops[:,2,2,prepost,n]), W, unrolled, U, t, mask)
861886
end
862887
end
863888
if n > 1 && prepost == 1
@@ -866,46 +891,7 @@ function lower_inner_block(ls::LoopSet, U::Int, T::Int, peel::Int = 1)
866891
end
867892
loopq_old = loopq
868893
end
869-
870-
@assert peel 0
871-
# this function create the inner block
872-
# args = Any[]
873-
nloops = length(order)
874-
unrolled = first(order)
875-
# included_syms = Set( (unrolled,) )
876-
included_vars = fill(false, length(operations(ls)))
877-
# to go inside out, we just have to include all those not-yet included depending on the current sym
878-
n = 0
879-
loopsym = last(order)
880-
blockq = Expr(:block, )#Expr(:(=), loopsym, 0))
881-
loopq = Expr(:while, looprange(ls, loopsym), blockq)
882-
for (id,op) enumerate(operations(ls))
883-
# We add an op the first time all loop dependencies are met
884-
# when working through loops backwords, that equates to the first time we encounter a loop dependency
885-
loopsym dependencies(op) || continue
886-
included_vars[id] = true
887-
lower!(blockq, op, unrolled, U)
888-
end
889-
for n 1:nloops - 1 - peel
890-
blockq = Expr(:block, Expr(:(=), loopsym, 0)) # sets old loopsym to 0
891-
loopsym = order[nloops - n]
892-
postloop = Expr(:block, )
893-
for (id,op) enumerate(operations(ls))
894-
included_vars[id] && continue
895-
# We add an op the first time all loop dependencies are met
896-
# when working through loops backwords, that equates to the first time we encounter a loop dependency
897-
loopsym dependencies(op) || continue
898-
included_vars[id] = true
899-
900-
after_loop = depends_on_assigned(op, included_vars)
901-
after_loop || lower!(blockq, op, unrolled, U)
902-
after_loop && lower!(postloop, op, unrolled, U)
903-
end
904-
push!(blockq.args, loopq_old); append!(blockq.args, postloop.args)
905-
push!(blockq, Expr(:+=, loopsym, 1))
906-
loopq = Expr(:while, looprange(ls, loopsym), blockq)
907-
end
908-
Expr(:block, Expr(:=, order[1 + peel], 0), loopq), included_vars
894+
loopq
909895
end
910896
function lower_unroll_static(ls::LoopSet, order::Vector{Symbol}, U::Int)
911897

0 commit comments

Comments
 (0)