Skip to content

Commit 5f8d5c1

Browse files
committed
More efficient pointer comparisons
1 parent 4b41e67 commit 5f8d5c1

File tree

11 files changed

+206
-168
lines changed

11 files changed

+206
-168
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.16"
4+
version = "0.12.17"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -30,7 +30,7 @@ Static = "0.2"
3030
StrideArraysCore = "0.1.5"
3131
ThreadingUtilities = "0.4.1"
3232
UnPack = "1"
33-
VectorizationBase = "0.19.34"
33+
VectorizationBase = "0.19.35"
3434
julia = "1.5"
3535

3636
[extras]

src/codegen/loopstartstopmanager.jl

Lines changed: 91 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
3+
14
function uniquearrayrefs_csesummary(ls::LoopSet)
25
uniquerefs = ArrayReferenceMeta[]
36
# each `Vector{Tuple{Int,Int}}` has the same name
@@ -100,21 +103,21 @@ function indices_calculated_by_pointer_offsets(ls::LoopSet, ar::ArrayReferenceMe
100103
out
101104
end
102105

103-
@generated function set_first_stride(sptr::StridedPointer{T,N,C,B,R}) where {T,N,C,B,R}
104-
minrank = argmin(R)
105-
newC = C > 0 ? (C == minrank ? 1 : 0) : C
106-
newB = C > 0 ? (C == minrank ? B : 0) : B #TODO: confirm correctness
107-
quote
108-
$(Expr(:meta,:inline))
109-
# VectorizationBase.StridedPointer{$T,1,$newC,$newB,$(R[minrank],)}($(lv(llvmptr))(sptr), (sptr.strd[$minrank],), (Zero(),))
110-
VectorizationBase.StridedPointer{$T,1,$newC,$newB,$(R[minrank],)}(VectorizationBase.cpupointer(sptr), (sptr.strd[$minrank],), (Zero(),))
111-
end
112-
end
113-
set_first_stride(x) = x # cross fingers that this works
114-
@inline onetozeroindexgephack(sptr::AbstractStridedPointer) = gesp(set_first_stride(sptr), (Static{-1}(),)) # go backwords
115-
@inline onetozeroindexgephack(sptr::AbstractStridedPointer{T,1}) where {T} = sptr
106+
# @generated function set_first_stride(sptr::StridedPointer{T,N,C,B,R}) where {T,N,C,B,R}
107+
# minrank = argmin(R)
108+
# newC = C > 0 ? (C == minrank ? 1 : 0) : C
109+
# newB = C > 0 ? (C == minrank ? B : 0) : B #TODO: confirm correctness
110+
# quote
111+
# $(Expr(:meta,:inline))
112+
# # VectorizationBase.StridedPointer{$T,1,$newC,$newB,$(R[minrank],)}($(lv(llvmptr))(sptr), (sptr.strd[$minrank],), (Zero(),))
113+
# VectorizationBase.StridedPointer{$T,1,$newC,$newB,$(R[minrank],)}(VectorizationBase.cpupointer(sptr), (sptr.strd[$minrank],), (Zero(),))
114+
# end
115+
# end
116+
# set_first_stride(x) = x # cross fingers that this works
117+
# @inline onetozeroindexgephack(sptr::AbstractStridedPointer) = gesp(set_first_stride(sptr), (Static{-1}(),)) # go backwords
118+
# @inline onetozeroindexgephack(sptr::AbstractStridedPointer{T,1}) where {T} = sptr
116119
# @inline onetozeroindexgephack(sptr::StridedPointer{T,1}) where {T} = sptr
117-
@inline onetozeroindexgephack(x) = x
120+
# @inline onetozeroindexgephack(x) = x
118121

119122
# # Removes parent/child relationship for all children with ref `ar`
120123
# function freechildren!(op::Operation, ar::ArrayReferenceMeta)
@@ -586,6 +589,7 @@ function use_loop_induct_var!(
586589
vpgesped = Expr(:call, lv(:offsetprecalc), vpgesped, Expr(:call, Expr(:curly, :Val, offsetprecalc_descript)))
587590
end
588591
push!(q.args, Expr(:(=), vptrar, vpgesped))
592+
push!(q.args, Expr(:(=), vptr_offset(vptrar), Expr(:call, GlobalRef(VectorizationBase, :increment_ptr), vptrar)))
589593
end
590594
uliv
591595
end
@@ -654,7 +658,7 @@ function pointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvec
654658
stop = last(loop)
655659
incr = step(loop)
656660
if isknown(start) & isknown(stop)
657-
pointermax(ls, ar, n, sub, isvectorized, 1 + gethint(stop) - gethint(start), incr)
661+
return pointermax(ls, ar, n, sub, isvectorized, 1 + gethint(stop) - gethint(start), incr)
658662
end
659663
looplensym = isone(start) ? getsym(stop) : loop.lensym
660664
pointermax(ls, ar, n, sub, isvectorized, looplensym, incr)
@@ -740,8 +744,9 @@ function pointermax_index(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int,
740744
index, ind
741745
end
742746
function pointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvectorized::Bool, stopsym, incr::MaybeKnown)::Expr
743-
index = first(pointermax_index(ls, ar, n, sub, isvectorized, stopsym, incr))
744-
Expr(:call, lv(:gesp), vptr(ar), index)
747+
index = first(pointermax_index(ls, ar, n, sub, isvectorized, stopsym, incr))
748+
vptrar = vptr(ar)
749+
Expr(:call, GlobalRef(VectorizationBase,:increment_ptr), vptrar, vptr_offset(vptrar), index)
745750
end
746751

747752
function defpointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvectorized::Bool)::Expr
@@ -767,58 +772,48 @@ end
767772
function append_pointer_maxes!(
768773
loopstart::Expr, ls::LoopSet, ar::ArrayReferenceMeta, n::Int, submax::Int, isvectorized::Bool, stopindicator, incr::MaybeKnown
769774
)
770-
vptr_ar = vptr(ar)
771-
if submax < 2
772-
for sub 0:submax
773-
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), pointermax(ls, ar, n, sub, isvectorized, stopindicator, incr)))
774-
end
775-
else
776-
index, ind = pointermax_index(ls, ar, n, submax, isvectorized, stopindicator, incr)
777-
pointercompbase = maxsym(vptr_ar, submax)
778-
push!(loopstart.args, Expr(:(=), pointercompbase, Expr(:call, lv(:gesp), vptr_ar, index)))
779-
dim = length(getindicesonly(ar))
780-
# OFFSETPRECALCDEF = true
781-
# if OFFSETPRECALCDEF
782-
strd = getstrides(ar)[ind]
783-
for sub 0:submax-1
784-
ptrcmp = Expr(:call, lv(:gesp), pointercompbase, offsetindex(dim, ind, (submax - sub)*strd, isvectorized, incr))
785-
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), ptrcmp))
786-
end
787-
# else
788-
# indexoff = offsetindex(dim, ind, 1, isvectorized)
789-
# for sub ∈ submax-1:-1:0
790-
# _newpointercompbase = maxsym(vptr_ar, sub)
791-
# newpointercompbase = gensym(_pointercompbase)
792-
# push!(loopstart.args, Expr(:(=), newpointercompbase, Expr(:call, lv(:gesp), pointercompbase, indexoff)))
793-
# push!(loopstart.args, Expr(:(=), _newpointercompbase, Expr(:call, lv(:pointerforcomparison), newpointercompbase)))
794-
# _pointercompbase = _newpointercompbase
795-
# pointercompbase = newpointercompbase
796-
# end
797-
# end
775+
vptr_ar = vptr(ar)
776+
if submax < 2
777+
for sub 0:submax
778+
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), pointermax(ls, ar, n, sub, isvectorized, stopindicator, incr)))
798779
end
780+
else
781+
index, ind = pointermax_index(ls, ar, n, submax, isvectorized, stopindicator, incr)
782+
pointercompbase = maxsym(vptr_ar, submax)
783+
ip = GlobalRef(VectorizationBase, :increment_ptr)
784+
push!(loopstart.args, Expr(:(=), pointercompbase, Expr(:call, ip, vptr_ar, vptr_offset(vptr_ar), index)))
785+
dim = length(getindicesonly(ar))
786+
# OFFSETPRECALCDEF = true
787+
# if OFFSETPRECALCDEF
788+
strd = getstrides(ar)[ind]
789+
for sub 0:submax-1
790+
ptrcmp = Expr(:call, ip, vptr_ar, pointercompbase, offsetindex(dim, ind, (submax - sub)*strd, isvectorized, incr))
791+
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), ptrcmp))
792+
end
793+
end
799794
end
800795
function append_pointer_maxes!(loopstart::Expr, ls::LoopSet, ar::ArrayReferenceMeta, n::Int, submax::Int, isvectorized::Bool)
801-
loop = getloop(ls, n)
802-
@assert loop.itersymbol == names(ls)[n]
803-
start = first(loop)
804-
stop = last(loop)
805-
incr = step(loop)
806-
if isknown(start) & isknown(stop)
807-
return append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, startstopΔ(loop)+1, incr)
808-
end
809-
looplensym = isone(start) ? getsym(stop) : loop.lensym
810-
append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, looplensym, incr)
796+
loop = getloop(ls, n)
797+
@assert loop.itersymbol == names(ls)[n]
798+
start = first(loop)
799+
stop = last(loop)
800+
incr = step(loop)
801+
if isknown(start) & isknown(stop)
802+
return append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, startstopΔ(loop)+1, incr)
803+
end
804+
looplensym = isone(start) ? getsym(stop) : loop.lensym
805+
append_pointer_maxes!(loopstart, ls, ar, n, submax, isvectorized, looplensym, incr)
811806
end
812807

813808
function maxunroll(us::UnrollSpecification, n)
814-
@unpack u₁loopnum, u₂loopnum, u₁, u₂ = us
815-
if n == u₁loopnum# && u₁ > 1
816-
u₁
817-
elseif n == u₂loopnum# && u₂ > 1
818-
u₂
819-
else
820-
1
821-
end
809+
@unpack u₁loopnum, u₂loopnum, u₁, u₂ = us
810+
if n == u₁loopnum# && u₁ > 1
811+
u₁
812+
elseif n == u₂loopnum# && u₂ > 1
813+
u₂
814+
else
815+
1
816+
end
822817
end
823818

824819

@@ -830,8 +825,8 @@ function startloop(ls::LoopSet, us::UnrollSpecification, n::Int, submax = maxunr
830825
loopstart = Expr(:block)
831826
firstloop = n == num_loops(ls)
832827
for ar ptrdefs
833-
ptr = vptr(ar)
834-
push!(loopstart.args, Expr(:(=), ptr, ptr))
828+
ptr_offset = vptr_offset(ar)
829+
push!(loopstart.args, Expr(:(=), ptr_offset, ptr_offset))
835830
end
836831
if iszero(termind)
837832
loopsym = names(ls)[n]
@@ -845,22 +840,24 @@ end
845840
function offset_ptr(
846841
ar::ArrayReferenceMeta, us::UnrollSpecification, loopsym::Symbol, n::Int, UF::Int, offsetinds::Vector{Bool}, loop::Loop
847842
)
848-
indices = getindices(ar)
849-
strides = getstrides(ar)
850-
offset = first(indices) === DISCONTIGUOUS
851-
gespinds = Expr(:tuple)
852-
li = ar.loopedindex
853-
for i eachindex(li)
854-
ii = i + offset
855-
ind = indices[ii]
856-
if !offsetinds[i] || ind !== loopsym
857-
push!(gespinds.args, Expr(:call, lv(:Zero)))
858-
else
859-
incrementloopcounter!(gespinds, us, n, UF * strides[i], loop)
860-
end
861-
# ind == loopsym && break
843+
indices = getindices(ar)
844+
strides = getstrides(ar)
845+
offset = first(indices) === DISCONTIGUOUS
846+
gespinds = Expr(:tuple)
847+
li = ar.loopedindex
848+
for i eachindex(li)
849+
ii = i + offset
850+
ind = indices[ii]
851+
if !offsetinds[i] || ind !== loopsym
852+
push!(gespinds.args, Expr(:call, lv(:Zero)))
853+
else
854+
incrementloopcounter!(gespinds, us, n, UF * strides[i], loop)
862855
end
863-
Expr(:(=), vptr(ar), Expr(:call, lv(:gesp), vptr(ar), gespinds))
856+
# ind == loopsym && break
857+
end
858+
vpoff = vptr_offset(ar)
859+
call = Expr(:call, GlobalRef(VectorizationBase, :increment_ptr), vptr(ar), vpoff, gespinds)
860+
Expr(:(=), vpoff, call)
864861
end
865862
function incrementloopcounter!(q::Expr, ls::LoopSet, us::UnrollSpecification, n::Int, UF::Int)
866863
@unpack u₁loopnum, u₂loopnum, vloopnum, u₁, u₂ = us
@@ -880,18 +877,19 @@ function incrementloopcounter!(q::Expr, ls::LoopSet, us::UnrollSpecification, n:
880877
nothing
881878
end
882879
function terminatecondition(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool, UF::Int)
883-
lssm = ls.lssm
884-
termind = lssm.terminators[n]
885-
if iszero(termind)
886-
loop = getloop(ls, n)
887-
return terminatecondition(loop, us, n, loop.itersymbol, inclmask, UF)
888-
end
880+
lssm = ls.lssm
881+
termind = lssm.terminators[n]
882+
if iszero(termind)
883+
loop = getloop(ls, n)
884+
return terminatecondition(loop, us, n, loop.itersymbol, inclmask, UF)
885+
end
889886

890-
termar = lssm.incrementedptrs[n][termind]
891-
ptr = vptr(termar)
892-
if inclmask && isvectorized(us, n)
893-
Expr(:call, :<, ptr, maxsym(ptr, 0))
894-
else
895-
Expr(:call, :, ptr, maxsym(ptr, UF))
896-
end
887+
termar = lssm.incrementedptrs[n][termind]
888+
ptr = vptr(termar)
889+
optr = vptr_offset(ptr)
890+
if inclmask && isvectorized(us, n)
891+
Expr(:call, GlobalRef(VectorizationBase, :vlt), optr, maxsym(ptr, 0), ptr)
892+
else
893+
Expr(:call, GlobalRef(VectorizationBase, :vle), optr, maxsym(ptr, UF), ptr)
894+
end
897895
end

src/codegen/lower_load.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, pr
7878
# gespinds.args[i] = Expr(:call, lv(:data), gespinds.args[i])
7979
end
8080
end
81-
push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
82-
81+
ip = GlobalRef(VectorizationBase, :increment_ptr)
82+
push!(q.args, Expr(:(=), gptr, Expr(:call, ip, ptr, vptr_offset(ptr), gespinds)))
8383
inds = Expr(:tuple)
8484
indices = getindicesonly(op)
8585

@@ -88,7 +88,9 @@ function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, pr
8888
push!(inds.args, Expr(:call, lv(:Zero)))
8989
(ind == u₁loopsym) && (i = j)
9090
end
91-
push!(q.args, Expr(:call, lv(:prefetch0), gptr, copy(inds)))
91+
prefetch0 = GlobalRef(VectorizationBase, :prefetch)
92+
push!(q.args, Expr(:call, prefetch0, Expr(:call, ip, ptr, gptr, copy(inds))))
93+
# push!(q.args, Expr(:call, lv(:prefetch0), gptr, copy(inds)))
9294
i == 0 && return
9395
for u 1:u₁-1
9496
# for u ∈ umin:min(umin,U-1)
@@ -107,14 +109,15 @@ function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, pr
107109
else
108110
inds.args[i] = staticexpr(u)
109111
end
110-
push!(q.args, Expr(:call, lv(:prefetch0), gptr, copy(inds)))
112+
push!(q.args, Expr(:call, prefetch0, Expr(:call, ip, ptr, gptr, copy(inds))))
111113
end
112114
nothing
113115
end
114116
broadcastedname(mvar) = Symbol(mvar, "##broadcasted##")
115117
function pushbroadcast!(q::Expr, mvar::Symbol)
116118
push!(q.args, Expr(:(=), broadcastedname(mvar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, mvar)))
117119
end
120+
118121
function lower_load_no_optranslation!(
119122
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}
120123
)
@@ -127,15 +130,16 @@ function lower_load_no_optranslation!(
127130
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls))
128131
if all(op.ref.loopedindex) && !rejectcurly(op)
129132
inds = unrolledindex(op, td, mask, inds_calc_by_ptr_offset, ls)
130-
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
133+
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
131134
add_memory_mask!(loadexpr, op, td, mask, ls)
132135
push!(loadexpr.args, falseexpr, rs) # unaligned load
133136
push!(q.args, Expr(:(=), mvar, loadexpr))
134137
elseif (u₁ > 1) & opu₁
135138
t = Expr(:tuple)
139+
sptrsym = sptr!(q, op)
136140
for u 1:u₁
137141
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, u-1, ls)
138-
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
142+
loadexpr = Expr(:call, lv(:_vload), sptrsym, inds)
139143
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
140144
add_memory_mask!(loadexpr, op, td, domask, ls)
141145
push!(loadexpr.args, falseexpr, rs)
@@ -145,7 +149,7 @@ function lower_load_no_optranslation!(
145149
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
146150
else
147151
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, 0, ls)
148-
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
152+
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
149153
add_memory_mask!(loadexpr, op, td, mask, ls)
150154
push!(loadexpr.args, falseexpr, rs)
151155
push!(q.args, Expr(:(=), mvar, loadexpr))
@@ -224,7 +228,8 @@ function lower_load_for_optranslation!(
224228
# gespinds.args[i] = Expr(:call, lv(:unmm), gespinds.args[i])
225229
end
226230
end
227-
push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
231+
ip = GlobalRef(VectorizationBase, :increment_ptr)
232+
push!(q.args, Expr(:(=), vptr_offset(gptr), Expr(:call, ip, ptr, vptr_offset(ptr), gespinds)))
228233
fill!(inds_by_ptroff, true)
229234
@unpack ref, loopedindex = mref
230235
indices = copy(getindices(ref))
@@ -452,8 +457,8 @@ function lower_load_collection!(
452457
false
453458
end
454459
uinds = Expr(:call, unrollcurl₂, inds)
455-
vp = vptr(op)
456-
loadexpr = Expr(:call, lv(:_vload), vp, uinds)
460+
sptrsym = sptr!(q, op)
461+
loadexpr = Expr(:call, lv(:_vload), sptrsym, uinds)
457462
# not using `add_memory_mask!(storeexpr, op, ua, mask, ls)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
458463
u₁vectorized = u₁loopsym === vloopsym
459464
if (mask && isvectorized(op))
@@ -462,7 +467,7 @@ function lower_load_collection!(
462467
end
463468
end
464469
push!(loadexpr.args, falseexpr, rs)
465-
collectionname = Symbol(vp, "##collection##number#", opidmap[first(first(idsformap))], "#", suffix, "##size##", nouter, "##u₁##", u₁)
470+
collectionname = Symbol(vptr(op), "##collection##number#", opidmap[first(first(idsformap))], "#", suffix, "##size##", nouter, "##u₁##", u₁)
466471
gf = GlobalRef(Core,:getfield)
467472
if manualunrollu₁
468473
masklast = mask & u₁vectorized & isvectorized(op)

src/codegen/lower_memory_common.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,17 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
172172
end
173173
ret
174174
end
175+
function sptr(op::Operation)
176+
vp = vptr(op)
177+
Expr(:call, GlobalRef(VectorizationBase, :reconstruct_ptr), vp, vptr_offset(vp))
178+
end
179+
function sptr!(q::Expr, op::Operation)
180+
vp = vptr(op)
181+
sptrsym = gensym(vp)
182+
push!(q.args, Expr(:(=), sptrsym, sptr(op)))
183+
sptrsym
184+
end
185+
175186
# function unrolled_curly(op::Operation, u₁::Int, u₁loopsym::Symbol, vectorized::Symbol, mask::Bool)
176187

177188
# interleave: `0` means `false`, positive means literal, negative means multiplier
@@ -183,7 +194,6 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
183194
li = op.ref.loopedindex
184195
# @assert all(loopedindex)
185196
# @unpack u₁, u₁loopsym, vloopsym = td
186-
# @show vptr(op), inds_calc_by_ptr_offset
187197
AV = AU = -1
188198
for (n,ind) enumerate(indices)
189199
# @show AU, op, n, ind, vloopsym, u₁loopsym

0 commit comments

Comments
 (0)