@@ -70,7 +70,6 @@ struct Operation
70
70
node_type,
71
71
variable = gensym ()
72
72
)
73
- # identifier = Threads.atomic_add!(ID, one(UInt))
74
73
new (
75
74
identifier, variable, elementbytes, instruction, node_type,
76
75
Set {Symbol} (), Operation[], Operation[], Int[], Symbol[]# , Dict{Symbol,Union{Symbol,Int}}()
@@ -165,7 +164,7 @@ isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
165
164
itersyms (ls:: LoopSet ) = keys (ls. loops)
166
165
function looprange (ls:: LoopSet , s:: Symbol )
167
166
loop = ls. loops[s]
168
- Expr (:(:), 0 , loop. hintexact ? loop. rangehint - 1 : Expr ( :call , :( - ), loop. rangesym, 1 ) )
167
+ Expr (:call , : < , s , loop. hintexact ? loop. rangehint : loop. rangesym)
169
168
end
170
169
function Base. length (ls:: LoopSet , is:: Symbol )
171
170
ls. loops[is]. rangehint
@@ -499,7 +498,8 @@ function depends_on_assigned(op::Operation, assigned::Vector{Bool})
499
498
end
500
499
false
501
500
end
502
- function replace_ind_in_offset! (offset:: Vector , op:: Operation , ind:: Int , dynamic:: Bool , t)
501
+ # ind gets increased across tiles / unroll, so we need steps.
502
+ function replace_ind_in_offset! (offset:: Vector , op:: Operation , ind:: Int , t)
503
503
t == 0 && return nothing
504
504
var = op. variable
505
505
siter = op. symbolic_metadata[ind]
@@ -538,7 +538,7 @@ function lower_load!(
538
538
end
539
539
else # tiling
540
540
for t ∈ 0 : T- 1
541
- replace_ind_inoffset ! (memoff, op, tind, t)
541
+ replace_ind_in_offset ! (memoff, op, tind, t)
542
542
for u ∈ 0 : U- 1
543
543
memoff2 = copy (memoff)
544
544
u > 0 && push! (memoff2, W* u)
@@ -553,7 +553,7 @@ function lower_load!(
553
553
ustrides = Expr (:tuple , (ustride > 1 ? [Core. VecElement {Int} (ustride* w) for w ∈ 0 : W- 1 ] : [:(Core. VecElement {Int} ($ (op. symbolic_metadata[upos])* $ w)) for w ∈ 0 : W- 1 ]). .. )
554
554
if T != - 1 # gather tile
555
555
for t ∈ 0 : T- 1
556
- replace_ind_inoffset ! (memoff, op, tind, t)
556
+ replace_ind_in_offset ! (memoff, op, tind, t)
557
557
for u ∈ 0 : U- 1
558
558
memoff2 = copy (memoff)
559
559
u > 0 && push! (memoff2, ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
@@ -580,7 +580,7 @@ function lower_load!(
580
580
# load per T.
581
581
# memoff2 = copy(memoff)
582
582
for t ∈ 0 : T- 1
583
- replace_ind_inoffset ! (memoff, op, tind, t)
583
+ replace_ind_in_offset ! (memoff, op, tind, t)
584
584
instrcall = Expr (:call , :load , ptr, copy (memoff))
585
585
# mask === nothing || push!(instrcall.args, mask)
586
586
push! (q. args, Expr (:(= ), Symbol (var,:_ ,t), instrcall))
@@ -613,7 +613,7 @@ function lower_store!(
613
613
end
614
614
else # tiling
615
615
for t ∈ 0 : T- 1
616
- replace_ind_inoffset ! (memoff, op, tind, t)
616
+ replace_ind_in_offset ! (memoff, op, tind, t)
617
617
for u ∈ 0 : U- 1
618
618
memoff2 = copy (memoff)
619
619
u > 0 && push! (memoff2, W* u)
@@ -628,7 +628,7 @@ function lower_store!(
628
628
ustrides = Expr (:tuple , (ustride > 1 ? [Core. VecElement {Int} (ustride* w) for w ∈ 0 : W- 1 ] : [:(Core. VecElement {Int} ($ (op. symbolic_metadata[upos])* $ w)) for w ∈ 0 : W- 1 ]). .. )
629
629
if T != - 1 # gather tile
630
630
for t ∈ 0 : T- 1
631
- replace_ind_inoffset ! (memoff, op, tind, t)
631
+ replace_ind_in_offset ! (memoff, op, tind, t)
632
632
for u ∈ 0 : U- 1
633
633
memoff2 = copy (memoff)
634
634
u > 0 && push! (memoff2, ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
@@ -658,11 +658,13 @@ function lower_store!(
658
658
# store per T.
659
659
# memoff2 = copy(memoff)
660
660
for t ∈ 0 : T- 1
661
- replace_ind_inoffset! (memoff, op, tind, t)
662
- push! (q. args, Expr (:call , :store! , ptr, Symbol (var,:_ ,t), copy (memoff)))
661
+ replace_ind_in_offset! (memoff, op, tind, t)
662
+ storevar = Expr (:call , reduct, Symbol (var,:_ ,t))
663
+ push! (q. args, Expr (:call , :store! , ptr, storevar, copy (memoff)))
663
664
end
664
665
else # no unroll
665
- push! (q. args, Expr (:call , :store! , var, ptr, memoff))
666
+ storevar = Expr (:call , reduct, var)
667
+ push! (q. args, Expr (:call , :store! , ptr, storevar, memoff))
666
668
end
667
669
end
668
670
end
@@ -673,6 +675,7 @@ function lower_compute!(
673
675
U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol (" ##UNDEFINED##" ), mask = nothing
674
676
)
675
677
opunrolled = unrolled ∈ loopdependencies (op)
678
+
676
679
optiled = tiled ∈ loopdependencies (op)
677
680
var = op. variable
678
681
instr = op. instruction
@@ -745,7 +748,8 @@ function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
745
748
lower_unroll_dynamic (ls, order, U)
746
749
end
747
750
end
748
- function lower_unroll_inner_block (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
751
+ function lower_unroll_inner_block (ls:: LoopSet , order:: Vector{Symbol} , U:: Int , peel:: Int = 1 )
752
+ @assert peel ≥ 0
749
753
# this function create the inner block
750
754
# args = Any[]
751
755
nloops = length (order)
@@ -755,39 +759,68 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
755
759
# to go inside out, we just have to include all those not-yet included depending on the current sym
756
760
n = 0
757
761
loopsym = last (order)
758
- blockq = Expr (:block , )
759
- loopq = Expr (:for , Expr (:( = ), itersym, looprange (ls, loopsym) ), blockq)
762
+ blockq = Expr (:block , )# Expr(:(=), loopsym, 0))
763
+ loopq = Expr (:while , looprange (ls, loopsym), blockq)
760
764
for (id,op) ∈ enumerate (operations (ls))
761
765
# We add an op the first time all loop dependencies are met
762
766
# when working through loops backwords, that equates to the first time we encounter a loop dependency
763
767
loopsym ∈ dependencies (op) || continue
764
768
included_vars[id] = true
765
769
lower! (blockq, op, unrolled, U)
766
770
end
767
- for n ∈ 1 : nloops - 2
771
+ for n ∈ 1 : nloops - 1 - peel
772
+ blockq = Expr (:block , Expr (:(= ), loopsym, 0 )) # sets old loopsym to 0
768
773
loopsym = order[nloops - n]
769
- blockq = Expr (:block , )
770
774
postloop = Expr (:block , )
771
775
for (id,op) ∈ enumerate (operations (ls))
772
776
included_vars[id] && continue
773
777
# We add an op the first time all loop dependencies are met
774
778
# when working through loops backwords, that equates to the first time we encounter a loop dependency
775
779
loopsym ∈ dependencies (op) || continue
776
780
included_vars[id] = true
777
-
781
+
778
782
after_loop = depends_on_assigned (op, included_vars)
779
783
after_loop || lower! (blockq, op, unrolled, U)
780
784
after_loop && lower! (postloop, op, unrolled, U)
781
785
end
782
786
push! (blockq. args, loopq_old); append! (blockq. args, postloop. args)
783
- loopq = Expr (:for , Expr (:(= ), itersym, looprange), blockq)
787
+ push! (blockq, Expr (:+= , loopsym, 1 ))
788
+ loopq = Expr (:while , looprange (ls, loopsym), blockq)
784
789
end
785
- loopq
790
+ Expr ( :block , Expr ( := , order[ 1 + peel], 0 ), loopq), included_vars
786
791
end
787
792
function lower_unroll_static (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
788
793
789
794
end
790
795
function lower_unroll_dynamic (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
796
+
797
+
798
+ unrolled = first (order)
799
+ q = Expr (:block , )
800
+
801
+ # we repeatedly break into smaller chunks.
802
+ while U > 0
803
+ inner_block, included_vars = lower_unroll_inner_block (ls, order, U, 1 )
804
+
805
+ end
806
+
807
+ Uispow2 = VectorizationBase. ispow2 (U)
808
+ looprange (ls, loopsym)
809
+
810
+ loop = ls. loops[s]
811
+ Expr (:(:), 0 , loop. hintexact ? loop. rangehint - 1 : Expr (:call , :(- ), loop. rangesym, 1 ))
812
+
813
+ if U == 1 # no unrolling needed
814
+
815
+ elseif Uispow2 # we use shifts and bitwise &
816
+ log2U = VectorizationBase. intlog2 (U)
817
+
818
+ else
819
+
820
+ end
821
+
822
+ # now must repeat inner block
823
+
791
824
nested_loop_syms = Set {Symbol} ()
792
825
# included_vars = Set{UInt}()
793
826
included_vars = fill (false , length (operations (ls)))
@@ -812,6 +845,7 @@ function lower_unroll_dynamic(ls::LoopSet, order::Vector{Symbol}, U::Int)
812
845
loopq = looprange (ls:: LoopSet , s:: Symbol )
813
846
end
814
847
blockq = Expr (:block , )
848
+
815
849
loopq = Expr (:for , Expr (:(= ), itersym, looprange), blockq)
816
850
for op ∈ operations (ls)
817
851
# won't define if already defined...
0 commit comments