Skip to content

Commit de9de65

Browse files
committed
Switch to LLVMPtr
1 parent 525d649 commit de9de65

File tree

9 files changed

+24
-33
lines changed

9 files changed

+24
-33
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using VectorizationBase: register_size, register_count, cache_linesize, cache_si
66
mask, pick_vector_width, MM, AbstractMask, data, grouped_strided_pointer,
77
maybestaticlength, maybestaticsize, staticm1, staticp1, staticmul, vzero,
88
maybestaticrange, offsetprecalc, lazymul,
9-
maybestaticfirst, maybestaticlast, scalar_less, scalar_greaterequal, gep, gesp, pointerforcomparison, NativeTypes,
9+
maybestaticfirst, maybestaticlast, scalar_less, scalar_greaterequal, gep, gesp, llvmptr, NativeTypes,
1010
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd_fast, vfmsub_fast, vfnmadd_fast, vfnmsub_fast, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231,
1111
vfma_fast, vmuladd_fast, vdiv_fast, vadd_fast, vsub_fast, vmul_fast,
1212
relu, stridedpointer, StridedPointer, StridedBitPointer, AbstractStridedPointer, _vload, _vstore!,

src/broadcast.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ Base.@propagate_inbounds Base.getindex(A::LowDimArray, i::Vararg{Union{Integer,C
162162
@inline Base.strides(A::LowDimArray) = strides(A.data)
163163
@inline ArrayInterface.parent_type(::Type{LowDimArray{D,T,N,A}}) where {T,D,N,A} = A
164164
@inline ArrayInterface.strides(A::LowDimArray) = ArrayInterface.strides(A.data)
165+
@inline ArrayInterface.device(::LowDimArray) = ArrayInterface.CPUPointer()
165166
@generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N}
166167
t = Expr(:tuple)
167168
for n 1:N
@@ -206,7 +207,7 @@ end
206207
Nnew += 1
207208
end
208209
typ = Expr(:curly, :StridedPointer, T, Nnew, Cnew, Bnew, Rtup)
209-
ptr = Expr(:call, typ, :(pointer(p)), strd, offsets)
210+
ptr = Expr(:call, typ, Expr(:call, lv(:llvmptr), :p), strd, offsets)
210211
Expr(:block, Expr(:meta,:inline), :(strd = p.strd), :(offs = p.offsets), ptr)
211212
end
212213
# @generated function VectorizationBase.stridedpointer(A::LowDimArray{D,T,N}) where {D,T,N}

src/codegen/loopstartstopmanager.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565
newB = C > 0 ? (C == minrank ? B : 0) : B #TODO: confirm correctness
6666
quote
6767
$(Expr(:meta,:inline))
68-
VectorizationBase.StridedPointer{$T,1,$newC,$newB,$(R[minrank],)}(pointer(sptr), (sptr.strd[$minrank],), (Zero(),))
68+
VectorizationBase.StridedPointer{$T,1,$newC,$newB,$(R[minrank],)}($(lv(llvmptr))(sptr), (sptr.strd[$minrank],), (Zero(),))
6969
end
7070
end
7171
set_first_stride(x) = x # cross fingers that this works
@@ -310,7 +310,7 @@ function pointermax_index(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int,
310310
end
311311
function pointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvectorized::Bool, stopsym, incr::MaybeKnown)::Expr
312312
index = first(pointermax_index(ls, ar, n, sub, isvectorized, stopsym, incr))
313-
Expr(:call, lv(:pointerforcomparison), vptr(ar), index)
313+
Expr(:call, lv(:gesp), vptr(ar), index)
314314
end
315315

316316
function defpointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, isvectorized::Bool)::Expr
@@ -347,16 +347,14 @@ function append_pointer_maxes!(
347347
# @show n, getloop(ls, n) ar
348348
index, ind = pointermax_index(ls, ar, n, submax, isvectorized, stopindicator, incr)
349349
vptr_ar = vptr(ar)
350-
_pointercompbase = maxsym(vptr_ar, submax)
351-
pointercompbase = gensym(_pointercompbase)
350+
pointercompbase = maxsym(vptr_ar, submax)
352351
push!(loopstart.args, Expr(:(=), pointercompbase, Expr(:call, lv(:gesp), vptr_ar, index)))
353-
push!(loopstart.args, Expr(:(=), _pointercompbase, Expr(:call, lv(:pointerforcomparison), pointercompbase)))
354352
dim = length(getindicesonly(ar))
355353
# OFFSETPRECALCDEF = true
356354
# if OFFSETPRECALCDEF
357355
strd = getstrides(ar)[dim]
358356
for sub 0:submax-1
359-
ptrcmp = Expr(:call, lv(:pointerforcomparison), pointercompbase, offsetindex(dim, ind, (submax - sub)*strd, isvectorized, incr))
357+
ptrcmp = Expr(:call, lv(:gesp), pointercompbase, offsetindex(dim, ind, (submax - sub)*strd, isvectorized, incr))
360358
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), ptrcmp))
361359
end
362360
# else
@@ -466,9 +464,9 @@ function terminatecondition(ls::LoopSet, us::UnrollSpecification, n::Int, inclma
466464
ptr = vptr(termar)
467465
# @show UF, isvectorized(us, n)
468466
if inclmask && isvectorized(us, n)
469-
Expr(:call, :<, callpointerforcomparison(ptr), maxsym(ptr, 0))
467+
Expr(:call, :<, ptr, maxsym(ptr, 0))
470468
else
471-
Expr(:call, :, callpointerforcomparison(ptr), maxsym(ptr, UF))
469+
Expr(:call, :, ptr, maxsym(ptr, UF))
472470
end
473471
end
474472

src/codegen/lower_threads.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function (::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V})(p::Ptr{UInt}) where {UNROLL,OPS,A
77
(_, _vargs) = ThreadingUtilities.load(p, Tuple{LB,V}, 2*sizeof(UInt))
88
# Main.VARGS[Threads.threadid()] = first(_vargs)
99
ret = _avx_!(Val{UNROLL}(), Val{OPS}(), Val{ARF}(), Val{AM}(), Val{LPSYM}(), _vargs)
10-
ThreadingUtilities.store!(p, ret, 64)
10+
ThreadingUtilities.store!(p, ret, Int(register_size()))
1111
nothing
1212
end
1313
@generated function Base.pointer(::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
@@ -28,16 +28,11 @@ function launch(
2828
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, lb::LB, vargs::V, tid
2929
) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
3030
p = ThreadingUtilities.taskpointer(tid)
31-
f = AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}()
32-
fptr = pointer(f)
31+
launch!(p, pointer(AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}()), (lb,vargs))
3332
while true
34-
if ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.SPIN, ThreadingUtilities.STUP)
35-
launch!(p, fptr, (lb,vargs))
36-
@assert ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.STUP, ThreadingUtilities.TASK)
33+
if ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.SPIN, ThreadingUtilities.TASK)
3734
return
38-
elseif ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.WAIT, ThreadingUtilities.STUP)
39-
launch!(p, fptr, (lb,vargs))
40-
@assert ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.STUP, ThreadingUtilities.LOCK)
35+
elseif ThreadingUtilities._atomic_cas_cmp!(p, ThreadingUtilities.WAIT, ThreadingUtilities.LOCK)
4136
ThreadingUtilities.wake_thread!(tid % UInt)
4237
return
4338
end

src/codegen/lowering.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,12 +425,10 @@ function pointerremcomparison(ls::LoopSet, termind::Int, UFt::Int, n::Int, nisve
425425
termar = lssm.incrementedptrs[n][termind]
426426
ptrdef = lssm.incrementedptrs[n][termind]
427427
ptr = vptr(termar)
428-
ptrex = callpointerforcomparison(ptr)
429428
if remfirst
430-
Expr(:call, :<, ptrex, pointermax(ls, ptrdef, n, 1 - UFt, nisvectorized, loop))
429+
Expr(:call, :<, ptr, pointermax(ls, ptrdef, n, 1 - UFt, nisvectorized, loop))
431430
else
432-
# Expr(:call, :≥, ptrex, pointermax(ls, ptrdef, n, UFt, nisvectorized, loop))
433-
Expr(:call, :, ptrex, maxsym(ptr, UFt))
431+
Expr(:call, :, ptr, maxsym(ptr, UFt))
434432
end
435433
end
436434

@@ -657,7 +655,7 @@ function init_remblock(unrolledloop::Loop, lssm::LoopStartStopManager, n::Int)#u
657655
else
658656
termar = lssm.incrementedptrs[n][termind]
659657
ptr = vptr(termar)
660-
condition = Expr(:call, :<, callpointerforcomparison(ptr), maxsym(ptr, 0))
658+
condition = Expr(:call, :<, vptr(ptr), maxsym(ptr, 0))
661659
end
662660
Expr(:if, condition)
663661
end

src/modeling/graphs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ function addexpr(ex, incr::Integer)
201201
end
202202

203203
staticmulincr(ptr, incr) = Expr(:call, lv(:staticmul), Expr(:call, :eltype, ptr), incr)
204-
callpointerforcomparison(sym) = Expr(:call, lv(:pointerforcomparison), sym)
205204
function vec_looprange(loop::Loop, UF::Int, mangledname)
206205
compexpr = Expr(:call, lv(:vsub_fast))
207206
pushexpr!(compexpr, last(loop))

src/simdfunctionals/filter.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ function vfilter!(f::F, x::Vector{T}, y::AbstractArray{T}) where {F,T <: NativeT
88
st = VectorizationBase.static_sizeof(T)
99
zero_index = MM(W, Static(0), st)
1010
GC.@preserve x y begin
11-
ptr_x = pointer(x)
12-
ptr_y = pointer(y)
11+
ptr_x = llvmptr(x)
12+
ptr_y = llvmptr(y)
1313
for _ 1:Nrep
14-
vy = vload(ptr_y, zero_index)
14+
vy = VectorizationBase.__vload(ptr_y, zero_index, False(), register_size())
1515
mask = f(vy)
1616
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
1717
ptr_y = gep(ptr_y, register_size())
1818
j = vadd_fast(j, count_ones(mask))
1919
end
2020
rem_mask = VectorizationBase.mask(T, Nrem)
21-
vy = vload(ptr_y, zero_index, rem_mask)
21+
vy = VectorizationBase.__vload(ptr_y, zero_index, rem_mask, False(), register_size())
2222
mask = rem_mask & f(vy)
2323
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
2424
j = vadd_fast(j, count_ones(mask))

src/simdfunctionals/map.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ end
107107
# :(_vmap_thread_call!($(F.instance), p, $D, $A, Val{$NonTemporal}()))
108108
# end
109109
function (m::VmapClosure{NonTemporal,F,D,N,A})(p::Ptr{UInt}) where {NonTemporal,F,D,N,A}
110-
(offset, dest) = ThreadingUtilities.load(p, D, 1)
110+
(offset, dest) = ThreadingUtilities.load(p, D, 2*sizeof(UInt))
111111
(offset, args) = ThreadingUtilities.load(p, A, offset)
112112

113113
(offset, start) = ThreadingUtilities.load(p, Int, offset)
@@ -132,7 +132,7 @@ end
132132
p, cfunc, ptry, ptrargs, start, stop
133133
)
134134
fptr = _get_fptr(cfunc)
135-
offset = ThreadingUtilities.store!(p, fptr, 0)
135+
offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt))
136136
offset = ThreadingUtilities.store!(p, ptry, offset)
137137
offset = ThreadingUtilities.store!(p, ptrargs, offset)
138138
offset = ThreadingUtilities.store!(p, start, offset)

src/simdfunctionals/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ function mapreduce_simple(f::F, op::OP, args::Vararg{AbstractArray,A}) where {F,
1919
N = length(first(args))
2020
iszero(N) && throw("Length of vector is 0!")
2121
st = ntuple(a -> VectorizationBase.static_sizeof(eltype(args[a])), Val(A))
22-
a_0 = f(vload.(ptrargs)...); i = 1
22+
a_0 = f(VectorizationBase.__vload.(ptrargs, False(), register_size())...); i = 1
2323
while i < N
24-
a_0 = op(a_0, f(vload.(ptrargs, VectorizationBase.lazymul.(st, i))...)); i += 1
24+
a_0 = op(a_0, f(VectorizationBase.vload.(ptrargs, VectorizationBase.lazymul.(st, i), False(), register_size())...)); i += 1
2525
end
2626
a_0
2727
end

0 commit comments

Comments
 (0)