Skip to content

Commit 61d7b85

Browse files
authored
Fix loop offsetting when loops don't start at 1. Fixes #301. (#302)
1 parent f40b6a6 commit 61d7b85

File tree

8 files changed

+367
-199
lines changed

8 files changed

+367
-199
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.51"
4+
version = "0.12.52"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -30,5 +30,5 @@ Static = "0.2"
3030
StrideArraysCore = "0.1.12"
3131
ThreadingUtilities = "0.4.5"
3232
UnPack = "1"
33-
VectorizationBase = "0.20.18"
33+
VectorizationBase = "0.20.21"
3434
julia = "1.5"

src/codegen/loopstartstopmanager.jl

Lines changed: 182 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11

2-
3-
42
function uniquearrayrefs_csesummary(ls::LoopSet)
53
uniquerefs = ArrayReferenceMeta[]
64
# each `Vector{Tuple{Int,Int}}` has the same name
@@ -195,43 +193,83 @@ function substitute_ops_all!(
195193
end
196194
end
197195
end
198-
function normalize_offsets!(
199-
ls::LoopSet, i::Int, allarrayrefs::Vector{ArrayReferenceMeta},
200-
array_refs_with_same_name::Vector{Int}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
201-
)
202-
ops = operations(ls)
203-
length(ops) > 128 && return 0
204-
minoffset::Int8 = typemax(Int8)
205-
maxoffset::Int8 = typemin(Int8)
206-
# we want to store the offsets, because we don't want to require that the `offset` vectors of the variaous `ArrayReferenceMeta`s don't alias
207-
offsets::Base.RefValue{NTuple{128,Int8}} = Base.RefValue{NTuple{128,Int8}}();
208-
GC.@preserve offsets begin
209-
poffsets = Base.unsafe_convert(Ptr{Int8}, offsets)
210-
for j ∈ array_refs_with_same_name
211-
arrayref_to_name_op = arrayref_to_name_op_collection[j]
212-
for (_,__,opid) ∈ arrayref_to_name_op
213-
op = ops[opid]
214-
off = getoffsets(op.ref)[i]
215-
off == zero(Int8) && return 0
216-
minoffset = min(off, minoffset)
217-
maxoffset = max(off, maxoffset)
218-
unsafe_store!(poffsets, off, opid)
219-
end
220-
end
221-
# reaching here means none of the offsets contain `0`
222-
# we won't bother if difference between offsets is >127
223-
# we don't want `maxoffset` to overflow when subtracting `minoffset`
224-
# so we check if it's safe, and give up if it isn't
225-
(Int(maxoffset) - Int(minoffset)) > 127 && return 0
226-
for j ∈ array_refs_with_same_name
227-
arrayref_to_name_op = arrayref_to_name_op_collection[j]
228-
for (_,__,opid) ∈ arrayref_to_name_op
229-
getoffsets(ops[opid].ref)[i] = unsafe_load(poffsets, opid) - minoffset
230-
end
231-
end
232-
end
233-
return Int(minoffset)
234-
end
196+
# function normalize_offsets!(
197+
# ls::LoopSet, i::Int, allarrayrefs::Vector{ArrayReferenceMeta},
198+
# array_refs_with_same_name::Vector{Int}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
199+
# )
200+
# ops = operations(ls)
201+
# length(ops) > 256 && return 0
202+
# minoffset::Int8 = typemax(Int8)
203+
# maxoffset::Int8 = typemin(Int8)
204+
# # we want to store the offsets, because we don't want to require that the `offset` vectors of the variaous `ArrayReferenceMeta`s don't alias
205+
# # loopsym = Symbol("##DUMMY##NOT#REALLY#A#LOOP##")
206+
# # stride::Int = typemin(Int)
207+
# # thereiszerooffset::Bool = false
208+
# # offsets::Base.RefValue{NTuple{128,Int8}} = Base.RefValue{NTuple{128,Int8}}();
209+
# # GC.@preserve offsets begin
210+
# # poffsets = Base.unsafe_convert(Ptr{Int8}, offsets)
211+
# for j ∈ array_refs_with_same_name
212+
# arrayref_to_name_op = arrayref_to_name_op_collection[j]
213+
# for (_,__,opid) ∈ arrayref_to_name_op
214+
# op = ops[opid]
215+
# opref = op.ref
216+
# off = getoffsets(opref)[i]
217+
# # thereiszerooffset |= off == zero(Int8)
218+
# # off == zero(Int8) && return 0
219+
# minoffset = min(off, minoffset)
220+
# maxoffset = max(off, maxoffset)
221+
# # unsafe_store!(poffsets, off, opid)
222+
# # if loopsym ≢ Symbol("##DUMMY##NOT#REALLY#A#LOOP##")
223+
# # stride = Int(getstrides(opref)[i])
224+
# # if opref.loopedindex[i]
225+
# # loopsym = getindicesonly(op)[i]
226+
# # else
227+
# # loopsym = Symbol("##NOT#A#LOOP##")
228+
# # end
229+
# # end
230+
# end
231+
# end
232+
# # reaching here means none of the offsets contain `0`
233+
# # we won't bother if difference between offsets is >127
234+
# # we don't want `maxoffset` to overflow when subtracting `minoffset`
235+
# # so we check if it's safe, and give up if it isn't
236+
# minoffsetint = Int(minoffset)
237+
# return (((Int(maxoffset) - minoffsetint) > 127)) ? 0 : minoffsetint
238+
239+
# # # if loopsym ≢ Symbol("##DUMMY##NOT#REALLY#A#LOOP##")
240+
# # # loop = getloop(ls, loopsym)
241+
# # # if minstride ≠ maxstride
242+
# # # @assert isknown(first(loop)) "Currently, if the same index is used for the same array with different multiples (e.g., `x[i]` and `x[2*i]`), then the start of that loop range must be known at compile time."
243+
# # # end
244+
# # # stride_offset = 1 - gethint(first(loop))
245+
# # # else
246+
# # # loop = first(ls.loops)
247+
# # # end
248+
# # offset_adjust = Int(minoffset)
249+
# # if (loopsym ≢ Symbol("##DUMMY##NOT#REALLY#A#LOOP##")) && (loopsym ≢ Symbol("##NOT#A#LOOP##"))
250+
# # loopstart = first(getloop(ls, loopsym))
251+
# # if isknown(loopstart)
252+
# # loopstartval = gethint(loopstart)
253+
# # offset_adjust = loopstartval*(stride - 1)
254+
# # if offset_adjust + Int(maxoffset) ≤ typemax(Int8)
255+
# # minoffset -= Int8(offset_adjust)
256+
# # stride = 1
257+
# # end
258+
# # end
259+
# # end
260+
# # for j ∈ array_refs_with_same_name
261+
# # arrayref_to_name_op = arrayref_to_name_op_collection[j]
262+
# # for (_,__,opid) ∈ arrayref_to_name_op
263+
# # new_offset = unsafe_load(poffsets, opid) - minoffset
264+
# # old_offset = getoffsets(ops[opid].ref)[i]
265+
# # @show new_offset, old_offset
266+
# # getoffsets(ops[opid].ref)[i] = new_offset
267+
# # # getoffsets(ops[opid].ref)[i] = unsafe_load(poffsets, opid) - minoffset
268+
# # end
269+
# # end
270+
# # end
271+
# # return @show offset_adjust, stride
272+
# end
235273
function isloopvalue(ls::LoopSet, ind::Symbol, isrooted::Union{Nothing,Vector{Bool}} = nothing)
236274
for (i,op) ∈ enumerate(operations(ls))
237275
if (isrooted ≢ nothing)
@@ -260,7 +298,7 @@ function cse_constant_offsets!(
260298
strides = getstrides(ar)
261299
offset = first(indices) === DISCONTIGUOUS
262300
# gespindoffsets = fill(Symbol(""), length(li))
263-
gespindsummary = Vector{Tuple{Symbol,Int}}(undef, length(li))
301+
gespindsummary = Vector{Symbol}(undef, length(li))
264302
for i ∈ eachindex(li)
265303
gespsymbol::Symbol = Symbol("")
266304
ii = i + offset
@@ -377,9 +415,9 @@ function cse_constant_offsets!(
377415
end
378416
end
379417
end
380-
constoffset = normalize_offsets!(ls, i, allarrayrefs, array_refs_with_same_name, arrayref_to_name_op_collection)
381-
gespindsummary[i] = (gespsymbol, constoffset)
382-
# pushgespind!(gespinds, ls, gespsymbol, constoffset, ind, li, i, check_shouldindbyind(ls, ind, shouldindbyind), true)
418+
# constoffset = normalize_offsets!(ls, i, allarrayrefs, array_refs_with_same_name, arrayref_to_name_op_collection)
419+
# gespindsummary[i] = (gespsymbol, constoffset)
420+
gespindsummary[i] = gespsymbol
383421
end
384422
return gespindsummary
385423
end
@@ -400,20 +438,97 @@ end
400438
# end
401439
# return nothing
402440
# end
403-
function calcgespinds(ls::LoopSet, ar::ArrayReferenceMeta, gespindsummary::Vector{Tuple{Symbol,Int}}, shouldindbyind::Vector{Bool})
441+
function adjust_offsets!(
442+
ls::LoopSet, i::Int,
443+
array_refs_with_same_name::Vector{Int}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
444+
)
445+
ops = operations(ls)
446+
@assert length(ops) ≤ 256
447+
offsets::Base.RefValue{NTuple{256,Int8}} = Base.RefValue{NTuple{256,Int8}}();
448+
GC.@preserve offsets begin
449+
poffsets = Base.unsafe_convert(Ptr{Int8}, offsets)
450+
minoffset = typemax(Int8)
451+
maxoffset = typemin(Int8)
452+
# stridesunequal = false
453+
for j ∈ array_refs_with_same_name
454+
arrayref_to_name_op = arrayref_to_name_op_collection[j]
455+
for (_,__,opid) ∈ arrayref_to_name_op
456+
opref = ops[opid].ref
457+
off = getoffsets(opref)[i]
458+
minoffset = min(off, minoffset)
459+
maxoffset = max(off, maxoffset)
460+
unsafe_store!(poffsets, off, opid)
461+
# stridesunequal |= (stride ≠ getstrides(opref)[i])
462+
end
463+
end
464+
constoffset = Int(minoffset)
465+
constoffset = Core.ifelse(Int(maxoffset) - constoffset > 127, 0, constoffset)
466+
if constoffset ≠ 0
467+
for j ∈ array_refs_with_same_name
468+
arrayref_to_name_op = arrayref_to_name_op_collection[j]
469+
for (_,__,opid) ∈ arrayref_to_name_op
470+
opref = ops[opid].ref
471+
newoffset = unsafe_load(poffsets, opid) - constoffset
472+
# if stridesunequal
473+
# stride = getstrides(opref)[i]
474+
# newoffsetint = Int(newoffset) + (Int(stride) - 1)
475+
# # @assert typemin(Int8) ≤ newoffsetint ≤ typemax(Int8)
476+
# newoffset = Int8(newoffsetint)
477+
# end
478+
getoffsets(ops[opid].ref)[i] = newoffset
479+
end
480+
end
481+
end
482+
end
483+
constoffset#, Core.ifelse(stridesunequal, 1, Int(stride))
484+
end
485+
486+
function calcgespinds(
487+
ls::LoopSet, ar::ArrayReferenceMeta, gespindsummary::Vector{Symbol}, shouldindbyind::Vector{Bool},
488+
array_refs_with_same_name::Vector{Int}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
489+
)
404490
gespinds = Expr(:tuple)
405491
li = ar.loopedindex
406492
indices = getindicesonly(ar)
493+
# offsets = getoffsets(ar)
494+
strides = getstrides(ar)
407495
for i ∈ eachindex(li)
408496
ind = indices[i]
409-
gespsymbol, constoffset = gespindsummary[i]
410-
pushgespind!(gespinds, ls, gespsymbol, constoffset, ind, li[i], check_shouldindbyind(ls, ind, shouldindbyind), true)
497+
isli = li[i]
498+
gespsymbol = gespindsummary[i]
499+
# if isli & (!index_by_index) && (length(operations(ls)) ≤ 256)
500+
# ops = operations(ls)
501+
# loopfirst = first(getloop(ls, ind))
502+
# if isknown(loopfirst)
503+
# # copy in case of aliasing
504+
# end
505+
# end
506+
# constoffset ≠ 0 &&
507+
constoffset = adjust_offsets!(ls, i, array_refs_with_same_name, arrayref_to_name_op_collection)
508+
index_by_index = isli ? check_shouldindbyind(ls, ind, shouldindbyind) : true
509+
# (stridesunequal & isli) && (@assert isknown(first(getloop(ls, ind))))
510+
511+
# end
512+
# stride = strides[i]
513+
# if stride ≠ 1
514+
# loop = getloop(ls, ind)
515+
# if isknown(first(loop))
516+
# offsets[i] -= gethint(first(loop))
517+
# end
518+
# end
519+
# # for op ∈ operations(ls)
520+
# # accesses_memory(op) || continue
521+
# # sameref(op.ref, ref) || continue
522+
# # getstrides(op)
523+
# # end
524+
# end
525+
pushgespind!(gespinds, ls, gespsymbol, constoffset, Int(strides[i]), ind, isli, index_by_index, true)
411526
end
412527
gespinds
413528
end
414529

415530
function pushgespind!(
416-
gespinds::Expr, ls::LoopSet, gespsymbol::Symbol, constoffset::Int, ind::Symbol, isli::Bool, index_by_index::Bool, fromgsp::Bool
531+
gespinds::Expr, ls::LoopSet, gespsymbol::Symbol, constoffset::Int, stride::Int, ind::Symbol, isli::Bool, index_by_index::Bool, fromgsp::Bool
417532
)
418533
if isli
419534
if ind === CONSTANTZEROINDEX
@@ -453,21 +568,32 @@ function pushgespind!(
453568
loop = getloop(ls, ind)
454569
if gespsymbol === Symbol("")
455570
if isknown(first(loop))
456-
push!(gespinds.args, staticexpr(constoffset + gethint(first(loop))))
571+
# @show constoffset, gethint(first(loop))
572+
push!(gespinds.args, staticexpr(constoffset + stride*gethint(first(loop))))
457573
elseif constoffset == 0
458-
push!(gespinds.args, getsym(first(loop)))
459-
else
574+
if stride == 1
575+
push!(gespinds.args, getsym(first(loop)))
576+
else
577+
push!(gespinds.args, mulexpr(getsym(first(loop)), stride))
578+
end
579+
elseif stride == 1
460580
push!(gespinds.args, addexpr(getsym(first(loop)), constoffset))
581+
else
582+
push!(gespinds.args, addexpr(mulexpr(getsym(first(loop)), stride), constoffset))
461583
end
462584
elseif isknown(first(loop))
463-
loopfirst = gethint(first(loop)) + constoffset
585+
loopfirst = gethint(first(loop))*stride + constoffset
464586
if loopfirst == 0
465587
push!(gespinds.args, gespsymbol)
466588
else
467589
push!(gespinds.args, Expr(:call, GlobalRef(Base, :(+)), gespsymbol, staticexpr(loopfirst)))
468590
end
469591
else
470-
addedstarts = Expr(:call, GlobalRef(Base, :(+)), gespsymbol, getsym(first(loop)))
592+
addedstarts = if stride == 1
593+
Expr(:call, GlobalRef(Base, :(+)), gespsymbol, getsym(first(loop)))
594+
else
595+
Expr(:call, GlobalRef(Base, :(+)), mulexpr(stride, gespsymbol), getsym(first(loop)))
596+
end
471597
if constoffset == 0
472598
push!(gespinds.args, addedstarts)
473599
else
@@ -563,25 +689,25 @@ function use_loop_induct_var!(
563689
if !li[i] # if it wasn't set
564690
uliv[i] = 0
565691
push!(offsetprecalc_descript.args, 0)
566-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, true, false)
692+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, 1, ind, isli, true, false)
567693
elseif ind === CONSTANTZEROINDEX
568694
uliv[i] = 0
569695
push!(offsetprecalc_descript.args, 0)
570-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, true, false)
696+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, 1, ind, isli, true, false)
571697
elseif isbroadcast ||
572698
((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) ||
573699
multiple_with_name(vptrar, allarrayrefs)) ||
574700
(iszero(ls.vector_width) && isstaticloop(getloop(ls, ind)))# ||
575701
# Not doing normal offset indexing
576702
uliv[i] = -findfirst(Base.Fix2(===,ind), looporder)::Int
577703
push!(offsetprecalc_descript.args, 0) # not doing offset indexing, so push 0
578-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, true, false)
704+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, 1, ind, isli, true, false)
579705
else
580706
uliv[i] = findfirst(Base.Fix2(===,ind), looporder)::Int
581707
loop = getloop(ls, ind)
582708
push!(offsetprecalc_descript.args, max(5,us.u₁+1,us.u₂+1))
583709
use_offsetprecalc = true
584-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, false, false)
710+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, 1, ind, isli, false, false)
585711
end
586712
# cases for pushgespind! and loopval!
587713
# if !isloopval, same as before

0 commit comments

Comments
 (0)