@@ -807,11 +807,32 @@ function lower!(
807
807
)
808
808
foreach (op -> lower! (q, op, W, unrolled, U, suffix, mask), ops)
809
809
end
810
+ function lower_load! (
811
+ q:: Expr , ops:: AbstractVector{Operation} , W:: Int , unrolled:: Symbol , U:: Int ,
812
+ suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
813
+ )
814
+ foreach (op -> lower_load! (q, op, W, unrolled, U, suffix, mask), ops)
815
+ end
816
+ function lower_compute! (
817
+ q:: Expr , ops:: AbstractVector{Operation} , W:: Int , unrolled:: Symbol , U:: Int ,
818
+ suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
819
+ )
820
+ foreach (op -> lower_compute! (q, op, W, unrolled, U, suffix, mask), ops)
821
+ end
822
+ function lower_store! (
823
+ q:: Expr , ops:: AbstractVector{Operation} , W:: Int , unrolled:: Symbol , U:: Int ,
824
+ suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
825
+ )
826
+ foreach (op -> lower_store! (q, op, W, unrolled, U, suffix, mask), ops)
827
+ end
810
828
function lower! (
811
- q:: Expr , op :: Operation , W:: Int , unrolled:: Symbol , U:: Int ,
829
+ q:: Expr , ops :: AbstractVector{<:AbstractVector{ Operation}} , W:: Int , unrolled:: Symbol , U:: Int ,
812
830
suffix:: Union{Nothing,Int} , mask:: Union{Nothing,Symbol,Unsigned} = nothing
813
831
)
814
- foreach (op -> lower! (q, op, W, unrolled, U, suffix, mask), ops)
832
+
833
+ foreach (op -> lower_load! (q, op, W, unrolled, U, suffix, mask), ops[1 ])
834
+ foreach (op -> lower_compute! (q, op, W, unrolled, U, suffix, mask), ops[2 ])
835
+ foreach (op -> lower_store! (q, op, W, unrolled, U, suffix, mask), ops[3 ])
815
836
end
816
837
817
838
@@ -840,6 +861,10 @@ function lower_inner_block(ls::LoopSet, U::Int, T::Int, peel::Int = 1)
840
861
tiled = Symbol (" ##UNDEFINED##" )
841
862
end
842
863
local loopq_old:: Expr
864
+ # Probably delete peel
865
+ # Have this function generate entire body
866
+ # And pass custom range, or range strategy
867
+ # to be used for unrolled and optionally tiled dims.
843
868
for n ∈ 1 : nloops - peel
844
869
loopsym = order[n]
845
870
blockq = if n == 1
@@ -850,14 +875,14 @@ function lower_inner_block(ls::LoopSet, U::Int, T::Int, peel::Int = 1)
850
875
loopq = Expr (:while , looprange (ls, loopsym), blockq)
851
876
for prepost ∈ 1 : 2
852
877
# !U && !T
853
- lower_scalar ! (blockq, @view (ops[:,1 ,1 ,prepost,n]), W, unrolled, U, nothing , mask)
878
+ lower ! (blockq, @view (ops[:,1 ,1 ,prepost,n]), W, unrolled, U, nothing , mask)
854
879
for u ∈ 0 : U- 1 # U && !T
855
- lower_unrolled ! (blockq, @view (ops[:,2 ,1 ,prepost,n]), W, unrolled, U, nothing , mask)
880
+ lower ! (blockq, @view (ops[:,2 ,1 ,prepost,n]), W, unrolled, U, nothing , mask)
856
881
end
857
882
for t ∈ 0 : Titer # !U && T
858
- lower_scalar ! (blockq, @view (ops[:,1 ,2 ,prepost,n]), W, unrolled, U, t, mask)
883
+ lower ! (blockq, @view (ops[:,1 ,2 ,prepost,n]), W, unrolled, U, t, mask)
859
884
for u ∈ 0 : U- 1 # U && T
860
- lower_unrolled ! (blockq, @view (ops[:,2 ,2 ,prepost,n]), W, unrolled, U, t, mask)
885
+ lower ! (blockq, @view (ops[:,2 ,2 ,prepost,n]), W, unrolled, U, t, mask)
861
886
end
862
887
end
863
888
if n > 1 && prepost == 1
@@ -866,46 +891,7 @@ function lower_inner_block(ls::LoopSet, U::Int, T::Int, peel::Int = 1)
866
891
end
867
892
loopq_old = loopq
868
893
end
869
-
870
- @assert peel ≥ 0
871
- # this function create the inner block
872
- # args = Any[]
873
- nloops = length (order)
874
- unrolled = first (order)
875
- # included_syms = Set( (unrolled,) )
876
- included_vars = fill (false , length (operations (ls)))
877
- # to go inside out, we just have to include all those not-yet included depending on the current sym
878
- n = 0
879
- loopsym = last (order)
880
- blockq = Expr (:block , )# Expr(:(=), loopsym, 0))
881
- loopq = Expr (:while , looprange (ls, loopsym), blockq)
882
- for (id,op) ∈ enumerate (operations (ls))
883
- # We add an op the first time all loop dependencies are met
884
- # when working through loops backwords, that equates to the first time we encounter a loop dependency
885
- loopsym ∈ dependencies (op) || continue
886
- included_vars[id] = true
887
- lower! (blockq, op, unrolled, U)
888
- end
889
- for n ∈ 1 : nloops - 1 - peel
890
- blockq = Expr (:block , Expr (:(= ), loopsym, 0 )) # sets old loopsym to 0
891
- loopsym = order[nloops - n]
892
- postloop = Expr (:block , )
893
- for (id,op) ∈ enumerate (operations (ls))
894
- included_vars[id] && continue
895
- # We add an op the first time all loop dependencies are met
896
- # when working through loops backwords, that equates to the first time we encounter a loop dependency
897
- loopsym ∈ dependencies (op) || continue
898
- included_vars[id] = true
899
-
900
- after_loop = depends_on_assigned (op, included_vars)
901
- after_loop || lower! (blockq, op, unrolled, U)
902
- after_loop && lower! (postloop, op, unrolled, U)
903
- end
904
- push! (blockq. args, loopq_old); append! (blockq. args, postloop. args)
905
- push! (blockq, Expr (:+= , loopsym, 1 ))
906
- loopq = Expr (:while , looprange (ls, loopsym), blockq)
907
- end
908
- Expr (:block , Expr (:= , order[1 + peel], 0 ), loopq), included_vars
894
+ loopq
909
895
end
910
896
function lower_unroll_static (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
911
897
0 commit comments