@@ -648,8 +648,13 @@ function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::
648
648
initialize_outer_reductions! (ifq, ls, Ulow, Uhigh, W, typeT, unrolled)
649
649
push! (ifq. args, loopq)
650
650
reduce_range! (ifq, ls, Ulow, Uhigh)
651
- comparison = Expr (:call , :! , Expr (:call , :< , unrolledloop. rangesym, Expr (:call , lv (:valmul ), W, Uhigh)))
652
- Expr (:if , comparison, ifq)
651
+ comparison = if unrolledloop. hintexact
652
+ Expr (:call , :< , unrolledloop. rangehint, Expr (:call , lv (:valmul ), W, Uhigh))
653
+ else
654
+ Expr (:call , :< , unrolledloop. rangesym, Expr (:call , lv (:valmul ), W, Uhigh))
655
+ end
656
+ ncomparison = Expr (:call , :! , comparison)
657
+ Expr (:if , ncomparison, ifq)
653
658
end
654
659
function reduce_expr! (q:: Expr , ls:: LoopSet , U:: Int )
655
660
for or ∈ ls. outer_reductions
@@ -805,10 +810,19 @@ function lower_unrolled_dynamic!(
805
810
end
806
811
else
807
812
remblocknew = if unrolled === vectorized
808
- comparison = Expr (:call , :> , unrolled, Expr (:call , :- , unrolled_numitersym, Expr (:call , lv (:valmuladd ), W, Ut, 1 )))
813
+ itercount = if unrolledloop. hintexact
814
+ Expr (:call , :- , unrolledloop. rangehint, Expr (:call , lv (:valmuladd ), W, Ut, 1 ))
815
+ else
816
+ Expr (:call , :- , unrolled_numitersym, Expr (:call , lv (:valmuladd ), W, Ut, 1 ))
817
+ end
818
+ comparison = Expr (:call , :> , unrolled, itercount)
809
819
Expr (Ut == 1 ? :if : :elseif , comparison, lower_set (ls, vectorized, Ut, T, W, Symbol (" ##mask##" ), :block ))
810
820
else
811
- comparison = Expr (:call , :> , unrolled, Expr (:call , :- , unrolled_numitersym, Ut + 1 ))
821
+ comparison = if unrolledloop. hintexact
822
+ Expr (:call , :> , unrolled, unrolledloop. rangehint - (Ut + 1 ))
823
+ else
824
+ Expr (:call , :> , unrolled, Expr (:call , :- , unrolled_numitersym, Ut + 1 ))
825
+ end
812
826
Expr (Ut == 1 ? :if : :elseif , comparison, lower_set (ls, vectorized, Ut, T, W, nothing , :block ))
813
827
end
814
828
push! (remblock. args, remblocknew)
@@ -824,7 +838,11 @@ function lower_unrolled_dynamic!(
824
838
end
825
839
Ut = 1
826
840
# setup for branchy remainder calculation
827
- comparison = Expr (:call , :(!= ), unrolled_numitersym, unrolled)
841
+ comparison = if unrolledloop. hintexact
842
+ Expr (:call , :(!= ), unrolledloop. rangehint, unrolled)
843
+ else
844
+ Expr (:call , :(!= ), unrolled_numitersym, unrolled)
845
+ end
828
846
remblock = Expr (:block )
829
847
push! (q. args, Expr (:if , comparison, remblock))
830
848
else
@@ -857,7 +875,6 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
857
875
W = gensym (:W )
858
876
typeT = gensym (:T )
859
877
setup_Wmask! (ls, W, typeT, vectorized, unrolled, U)
860
- # W = VectorizationBase.pick_vector_width(ls, unrolled)
861
878
tiledloop = ls. loops[tiled]
862
879
static_tile = tiledloop. hintexact
863
880
unrolledloop = ls. loops[unrolled]
@@ -866,8 +883,6 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
866
883
# we build up the loop expression.
867
884
Trem = Tt = T
868
885
nloops = num_loops (ls);
869
- # addtileonly = sum(length, @view(oporder(ls)[:,:,:,:,end])) > 0
870
- # Texprtype = (static_tile && tiled_iter < 2T) ? :block : :while
871
886
firstiter = true
872
887
mangledtiled = tiledsym (tiled)
873
888
local qifelse:: Expr
@@ -876,7 +891,7 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
876
891
lower_unrolled! (tiledloopbody, ls, vectorized, U, Tt, W, typeT, unrolledloop)
877
892
tiledloopbody = lower_nest (ls, nloops, vectorized, U, Tt, tiledloopbody, 0 , W, nothing , :block )
878
893
if firstiter
879
- push! (q. args, (static_tile && tiled_iter < 2 T) ? tiledloopbody : Expr (:while , looprange (ls, tiled, Tt, mangledtiled, tiledloop), tiledloopbody))
894
+ push! (q. args, (static_tile && tiledloop . rangehint < 2 T) ? tiledloopbody : Expr (:while , looprange (ls, tiled, Tt, mangledtiled, tiledloop), tiledloopbody))
880
895
elseif static_tile
881
896
push! (q. args, tiledloopbody)
882
897
else # not static, not firstiter
@@ -887,7 +902,6 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
887
902
end
888
903
if static_tile
889
904
if Tt == T
890
- # push!(tiledloopbody.args, Expr(:+=, mangledtiled, Tt))
891
905
Texprtype = :block
892
906
Tt = looprangehint (ls, tiled) % T
893
907
# Recalculate U
0 commit comments