Skip to content

Commit f965c5f

Browse files
committed
Fixed a few bugs for handling static sizes.
1 parent e8f621a commit f965c5f

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

src/broadcast.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function add_broadcast!(
131131
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
132132
) where {T<:Union{Integer,Float32,Float64}}
133133
pushpreamble!(ls, Expr(:(=), Symbol("##", destname), bcname))
134-
add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ?
134+
add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ? u
135135
end
136136
function add_broadcast!(
137137
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
@@ -144,9 +144,9 @@ function add_broadcast!(
144144
end
145145
function add_broadcast!(
146146
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
147-
::Type{Broadcasted{DefaultArrayStyle{N},Nothing,F,A}},
147+
::Type{Broadcasted{S,Nothing,F,A}},
148148
elementbytes::Int = 8
149-
) where {N,F,A}
149+
) where {N,S<:Base.Broadcast.AbstractArrayStyle{N},F,A}
150150
instr = get(FUNCTIONSYMBOLS, F) do
151151
f = gensym(:f)
152152
pushpreamble!(ls, Expr(:(=), f, Expr(:(.), bcname, QuoteNode(:f))))
@@ -224,7 +224,7 @@ end
224224
# ls
225225
end
226226

227-
function vmaterialize(bc::Broadcasted)
227+
@inline function vmaterialize(bc::Broadcasted)
228228
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
229229
vmaterialize!(similar(bc, ElType), bc)
230230
end

src/lowering.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -648,8 +648,13 @@ function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::
648648
initialize_outer_reductions!(ifq, ls, Ulow, Uhigh, W, typeT, unrolled)
649649
push!(ifq.args, loopq)
650650
reduce_range!(ifq, ls, Ulow, Uhigh)
651-
comparison = Expr(:call, :!, Expr(:call, :<, unrolledloop.rangesym, Expr(:call, lv(:valmul), W, Uhigh)))
652-
Expr(:if, comparison, ifq)
651+
comparison = if unrolledloop.hintexact
652+
Expr(:call, :<, unrolledloop.rangehint, Expr(:call, lv(:valmul), W, Uhigh))
653+
else
654+
Expr(:call, :<, unrolledloop.rangesym, Expr(:call, lv(:valmul), W, Uhigh))
655+
end
656+
ncomparison = Expr(:call, :!, comparison)
657+
Expr(:if, ncomparison, ifq)
653658
end
654659
function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
655660
for or ls.outer_reductions
@@ -805,10 +810,19 @@ function lower_unrolled_dynamic!(
805810
end
806811
else
807812
remblocknew = if unrolled === vectorized
808-
comparison = Expr(:call, :>, unrolled, Expr(:call, :-, unrolled_numitersym, Expr(:call, lv(:valmuladd), W, Ut, 1)))
813+
itercount = if unrolledloop.hintexact
814+
Expr(:call, :-, unrolledloop.rangehint, Expr(:call, lv(:valmuladd), W, Ut, 1))
815+
else
816+
Expr(:call, :-, unrolled_numitersym, Expr(:call, lv(:valmuladd), W, Ut, 1))
817+
end
818+
comparison = Expr(:call, :>, unrolled, itercount)
809819
Expr(Ut == 1 ? :if : :elseif, comparison, lower_set(ls, vectorized, Ut, T, W, Symbol("##mask##"), :block))
810820
else
811-
comparison = Expr(:call, :>, unrolled, Expr(:call, :-, unrolled_numitersym, Ut + 1))
821+
comparison = if unrolledloop.hintexact
822+
Expr(:call, :>, unrolled, unrolledloop.rangehint - (Ut + 1))
823+
else
824+
Expr(:call, :>, unrolled, Expr(:call, :-, unrolled_numitersym, Ut + 1))
825+
end
812826
Expr(Ut == 1 ? :if : :elseif, comparison, lower_set(ls, vectorized, Ut, T, W, nothing, :block))
813827
end
814828
push!(remblock.args, remblocknew)
@@ -824,7 +838,11 @@ function lower_unrolled_dynamic!(
824838
end
825839
Ut = 1
826840
# setup for branchy remainder calculation
827-
comparison = Expr(:call, :(!=), unrolled_numitersym, unrolled)
841+
comparison = if unrolledloop.hintexact
842+
Expr(:call, :(!=), unrolledloop.rangehint, unrolled)
843+
else
844+
Expr(:call, :(!=), unrolled_numitersym, unrolled)
845+
end
828846
remblock = Expr(:block)
829847
push!(q.args, Expr(:if, comparison, remblock))
830848
else
@@ -857,7 +875,6 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
857875
W = gensym(:W)
858876
typeT = gensym(:T)
859877
setup_Wmask!(ls, W, typeT, vectorized, unrolled, U)
860-
# W = VectorizationBase.pick_vector_width(ls, unrolled)
861878
tiledloop = ls.loops[tiled]
862879
static_tile = tiledloop.hintexact
863880
unrolledloop = ls.loops[unrolled]
@@ -866,8 +883,6 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
866883
# we build up the loop expression.
867884
Trem = Tt = T
868885
nloops = num_loops(ls);
869-
# addtileonly = sum(length, @view(oporder(ls)[:,:,:,:,end])) > 0
870-
# Texprtype = (static_tile && tiled_iter < 2T) ? :block : :while
871886
firstiter = true
872887
mangledtiled = tiledsym(tiled)
873888
local qifelse::Expr
@@ -876,7 +891,7 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
876891
lower_unrolled!(tiledloopbody, ls, vectorized, U, Tt, W, typeT, unrolledloop)
877892
tiledloopbody = lower_nest(ls, nloops, vectorized, U, Tt, tiledloopbody, 0, W, nothing, :block)
878893
if firstiter
879-
push!(q.args, (static_tile && tiled_iter < 2T) ? tiledloopbody : Expr(:while, looprange(ls, tiled, Tt, mangledtiled, tiledloop), tiledloopbody))
894+
push!(q.args, (static_tile && tiledloop.rangehint < 2T) ? tiledloopbody : Expr(:while, looprange(ls, tiled, Tt, mangledtiled, tiledloop), tiledloopbody))
880895
elseif static_tile
881896
push!(q.args, tiledloopbody)
882897
else # not static, not firstiter
@@ -887,7 +902,6 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
887902
end
888903
if static_tile
889904
if Tt == T
890-
# push!(tiledloopbody.args, Expr(:+=, mangledtiled, Tt))
891905
Texprtype = :block
892906
Tt = looprangehint(ls, tiled) % T
893907
# Recalculate U

0 commit comments

Comments
 (0)