Skip to content

Commit 70a632f

Browse files
committed
A few more tests pass.
1 parent e05ffca commit 70a632f

File tree

10 files changed

+92
-98
lines changed

10 files changed

+92
-98
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ using VectorizationBase, SLEEFPirates, UnPack, OffsetArrays
88
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, data,
99
mask, pick_vector_width_val, MM,
1010
maybestaticlength, maybestaticsize, staticm1, staticp1, staticmul, vzero,
11-
Zero, maybestaticrange, offsetprecalc,
11+
Zero, maybestaticrange, offsetprecalc, lazymul,
1212
maybestaticfirst, maybestaticlast, scalar_less, gep, gesp, pointerforcomparison, NativeTypes,
1313
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, vadd, vsub, vmul,
1414
relu, stridedpointer, StridedPointer, AbstractStridedPointer,
1515
reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, reduced_max, reduced_min, reduce_to_max, reduce_to_min,
16-
vsum, vprod, vmaximum, vminimum
16+
vsum, vprod, vmaximum, vminimum, vstorent!
1717

1818
using IfElse: ifelse
1919

src/filter.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@ if (Base.libllvm_version ≥ v"7" && VectorizationBase.AVX512F) || Base.libllvm_
77
Nrem = N & (W - 1)
88
j = 0
99
st = VectorizationBase.static_sizeof(T)
10+
zero_index = MM{W}(Static(0), st)
1011
GC.@preserve x y begin
1112
ptr_x = pointer(x)
1213
ptr_y = pointer(y)
1314
for _ ∈ 1:Nrep
14-
vy = vload(ptr_y, MM{W}(Static(0), st))
15+
vy = vload(ptr_y, zero_index)
1516
mask = f(vy)
1617
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
1718
ptr_y = gep(ptr_y, VectorizationBase.REGISTER_SIZE)
1819
j = vadd(j, count_ones(mask))
1920
end
2021
rem_mask = VectorizationBase.mask(T, Nrem)
21-
vy = vload(ptr_y, MM{W}(Static(0), st), rem_mask)
22+
vy = vload(ptr_y, zero_index, rem_mask)
2223
mask = rem_mask & f(vy)
2324
VectorizationBase.compressstore!(gep(ptr_x, VectorizationBase.lazymul(st, j)), vy, mask)
2425
j = vadd(j, count_ones(mask))

src/graphs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ function vec_looprange(loopmax, UF::Int, mangledname::Symbol, ptrcomp::Bool)
106106
end
107107
function vec_looprange(loopmax, UF::Int, mangledname, W)
108108
incr = if isone(UF)
109-
Expr(:call, lv(:vsub), W, :(Static{1}()))
109+
Expr(:call, lv(:vsub), W, staticexpr(1))
110110
else
111-
Expr(:call, lv(:vsub), Expr(:call, lv(:vmul), W, UF), :(Static{1}()))
111+
Expr(:call, lv(:vsub), Expr(:call, lv(:vmul), W, UF), staticexpr(1))
112112
end
113113
compexpr = subexpr(loopmax, incr)
114114
Expr(:call, :<, mangledname, compexpr)
@@ -142,7 +142,7 @@ function incrementloopcounter(us::UnrollSpecification, n::Int, mangledname::Symb
142142
if isone(UF)
143143
Expr(:(=), mangledname, Expr(:call, lv(:vadd), VECTORWIDTHSYMBOL, mangledname))
144144
else
145-
Expr(:(=), mangledname, Expr(:call, lv(:vadd), Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, :(Static{$UF}())), mangledname))
145+
Expr(:(=), mangledname, Expr(:call, lv(:vadd), Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, staticexpr(UF)), mangledname))
146146
end
147147
else
148148
Expr(:(=), mangledname, Expr(:call, lv(:vadd), mangledname, UF))
@@ -156,7 +156,7 @@ function incrementloopcounter!(q, us::UnrollSpecification, n::Int, UF::Int = unr
156156
push!(q.args, Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, Expr(:call, Expr(:curly, lv(:Static), UF))))
157157
end
158158
else
159-
push!(q.args, Expr(:call, Expr(:curly, lv(:Static), UF)))
159+
push!(q.args, staticexpr(UF))
160160
end
161161
end
162162
function looplengthexpr(loop::Loop)

src/lower_memory_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function add_vectorized_offset_unrolled!(ret::Expr, offset, incr)
111111
push!(ret.args, _MMind(staticexpr(offset)))
112112
end
113113
elseif iszero(offset)
114-
push!(ret.args, _MMind(Expr(:call, lv(:staticmul), VECTORWIDTHSYMBOL, maybestatic(incr))))
114+
push!(ret.args, _MMind(Expr(:call, lv(:*), VECTORWIDTHSYMBOL, maybestatic(incr))))
115115
else
116116
push!(ret.args, _MMind(Expr(:call, lv(:vadd), Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, maybestatic(incr)), staticexpr(offset))))
117117
end

src/lowering.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ function loopvarremcomparison(loop::Loop, UFt::Int, nisvectorized::Bool, remfirs
458458
itercount = if loop.stopexact
459459
Expr(:call, lv(:vsub), loop.stophint - 1, Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, UFt))
460460
else
461-
Expr(:call, lv(:vsub), loop.stopsym, Expr(:call, lv(:vadd), Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, UFt), :(Static{1}())))
461+
Expr(:call, lv(:vsub), loop.stopsym, Expr(:call, lv(:vadd), Expr(:call, lv(:vmul), VECTORWIDTHSYMBOL, UFt), staticexpr(1)))
462462
end
463463
Expr(:call, :>, loopsym, itercount)
464464
elseif remfirst
@@ -528,7 +528,7 @@ function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::
528528
elseif unrolledloop.stopexact
529529
Expr(:call, lv(:scalar_less), Expr(:call, lv(:vsub), unrolledloop.stophint+1, unrolledloop.sartsym), loopbuffer)
530530
else# both are given by symbols
531-
Expr(:call, lv(:scalar_less), Expr(:call, lv(:vsub), unrolledloop.stopsym, Expr(:call,lv(:vsub),unrolledloop.startsym, Expr(:call,lv(:Static),1))), loopbuffer)
531+
Expr(:call, lv(:scalar_less), Expr(:call, lv(:vsub), unrolledloop.stopsym, Expr(:call,lv(:vsub),unrolledloop.startsym, staticexpr(1))), loopbuffer)
532532
end
533533
ncomparison = Expr(:call, :!, comparison)
534534
Expr(:if, ncomparison, ifq)

src/map.jl

Lines changed: 37 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,22 @@ function alignstores!(
88
args::Vararg{<:DenseArray{<:Base.HWReal},A}
99
) where {F, T <: Base.HWReal, A}
1010
N = length(y)
11-
ptry = pointer(y)
12-
ptrargs = pointer.(args)
13-
W = VectorizationBase.pick_vector_width(T)
11+
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
12+
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
1413
V = VectorizationBase.pick_vector_width_val(T)
15-
@assert iszero(reinterpret(UInt, ptry) & (sizeof(T) - 1)) "The destination vector (`dest`) must be aligned at least to `sizeof(eltype(dest))`."
16-
alignment = reinterpret(UInt, ptry) & (VectorizationBase.REGISTER_SIZE - 1)
14+
W = unwrap(V)
15+
zero_index = MM{W}(Static(0))
16+
uintptry = reinterpret(UInt, pointer(ptry))
17+
@assert iszero(uintptry & (sizeof(T) - 1)) "The destination vector (`dest`) must be aligned at least to `sizeof(eltype(dest))`."
18+
alignment = uintptry & (VectorizationBase.REGISTER_SIZE - 1)
1719
if alignment > 0
1820
i = reinterpret(Int, W - (alignment >>> VectorizationBase.intlog2(sizeof(T))))
1921
m = mask(T, i)
2022
if N < i
2123
m &= mask(T, N & (W - 1))
2224
end
23-
vnoaliasstore!(ptry, f(vload.(V, ptrargs, m)...), m)
24-
gep(ptry, i), gep.(ptrargs, i), N - i
25+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((zero_index,),), m)...), (zero_index,), m)
26+
gesp(ptry, (i,)), gesp.(ptrargs, ((i,),)), N - i
2527
else
2628
ptry, ptrargs, N
2729
end
@@ -32,46 +34,44 @@ function vmap_singlethread!(
3234
::Val{NonTemporal},
3335
args::Vararg{<:DenseArray{<:Base.HWReal},A}
3436
) where {F,T <: Base.HWReal, A, NonTemporal}
35-
if NonTemporal
37+
if NonTemporal # if stores into `y` aren't aligned, we'll get a crash
3638
ptry, ptrargs, N = alignstores!(f, y, args...)
3739
else
3840
N = length(y)
39-
ptry = pointer(y)
40-
ptrargs = pointer.(args)
41+
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
42+
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
4143
end
4244
i = 0
43-
W = VectorizationBase.pick_vector_width(T)
4445
V = VectorizationBase.pick_vector_width_val(T)
46+
W = unwrap(V)
47+
st = VectorizationBase.static_sizeof(T)
48+
zero_index = MM{W}(Static(0), st)
4549
while i < N - ((W << 2) - 1)
46-
v₁ = f(vload.(V, gep.(ptrargs, i ))...)
47-
v₂ = f(vload.(V, gep.(ptrargs, vadd(i, W)))...)
48-
v₃ = f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...)
49-
v₄ = f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...)
50+
51+
# vstore!(stridedpointer(B), VectorizationBase.VecUnroll((v1,v2,v3)), VectorizationBase.Unroll{AU,1,3,AV,W64,zero(UInt)}((i, j, k)))
52+
# vload(stridedpointer(B), VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((i,)))
53+
54+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((i,))
55+
v = f(vload.(ptrargs, index)...)
5056
if NonTemporal
51-
vstorent!(gep(ptry, i ), v₁)
52-
vstorent!(gep(ptry, vadd(i, W)), v₂)
53-
vstorent!(gep(ptry, vadd(i, 2W)), v₃)
54-
vstorent!(gep(ptry, vadd(i, 3W)), v₄)
57+
vstorent!(ptry, v, index)
5558
else
56-
vnoaliasstore!(gep(ptry, i ), v₁)
57-
vnoaliasstore!(gep(ptry, vadd(i, W)), v₂)
58-
vnoaliasstore!(gep(ptry, vadd(i, 2W)), v₃)
59-
vnoaliasstore!(gep(ptry, vadd(i, 3W)), v₄)
59+
vnoaliasstore!(ptry, v, index)
6060
end
6161
i = vadd(i, 4W)
6262
end
6363
while i < N - (W - 1) # stops at 16 when
64-
vᵢ = f(vload.(V, gep.(ptrargs, i))...)
64+
vᵣ = f(vload.(ptrargs, ((MM{W}(i),),))...)
6565
if NonTemporal
66-
vstorent!(gep(ptry, i), vᵢ)
66+
vstorent!(ptry, vᵣ, (MM{W}(i),))
6767
else
68-
vnoaliasstore!(gep(ptry, i), vᵢ)
68+
vnoaliasstore!(ptry, vᵣ, (MM{W}(i),))
6969
end
7070
i = vadd(i, W)
7171
end
7272
if i < N
7373
m = mask(T, N & (W - 1))
74-
vnoaliasstore!(gep(ptry, i), f(vload.(V, gep.(ptrargs, i), m)...), m)
74+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
7575
end
7676
y
7777
end
@@ -89,25 +89,17 @@ function vmap_multithreaded!(
8989
Wsh = Wshift + 2
9090
Niter = N >>> Wsh
9191
Base.Threads.@threads for j ∈ 0:Niter-1
92-
i = j << Wsh
93-
v₁ = f(vload.(V, gep.(ptrargs, i ))...)
94-
v₂ = f(vload.(V, gep.(ptrargs, vadd(i, W)))...)
95-
v₃ = f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...)
96-
v₄ = f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...)
97-
vstorent!(gep(ptry, i ), v₁)
98-
vstorent!(gep(ptry, vadd(i, W)), v₂)
99-
vstorent!(gep(ptry, vadd(i, 2W)), v₃)
100-
vstorent!(gep(ptry, vadd(i, 3W)), v₄)
92+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((j << Wsh,))
93+
vstorent!(ptry, f(vload.(ptrargs, index)...), index)
10194
end
10295
ii = Niter << Wsh
10396
while ii < N - (W - 1) # stops at 16 when
104-
vᵢ = f(vload.(V, gep.(ptrargs, ii))...)
105-
vstorent!(gep(ptry, ii), vᵢ)
97+
vstorent!(ptry, f(vload.(ptrargs, ((MM{W}(ii),),))...), (MM{W}(ii),))
10698
ii = vadd(ii, W)
10799
end
108100
if ii < N
109101
m = mask(T, N & (W - 1))
110-
vnoaliasstore!(gep(ptry, ii), f(vload.(V, gep.(ptrargs, ii), m)...), m)
102+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(ii),),), m)...), (MM{W}(ii),), m)
111103
end
112104
y
113105
end
@@ -118,33 +110,25 @@ function vmap_multithreaded!(
118110
args::Vararg{<:DenseArray{<:Base.HWReal},A}
119111
) where {F,T,A}
120112
N = length(y)
121-
ptry = pointer(y)
122-
ptrargs = pointer.(args)
113+
ptry = VectorizationBase.zero_offsets(stridedpointer(y))
114+
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
123115
N > 0 || return y
124116
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
125117
V = VectorizationBase.pick_vector_width_val(T)
126118
Wsh = Wshift + 2
127119
Niter = N >>> Wsh
128120
Base.Threads.@threads for j ∈ 0:Niter-1
129-
i = j << Wsh
130-
v₁ = f(vload.(V, gep.(ptrargs, i ))...)
131-
v₂ = f(vload.(V, gep.(ptrargs, vadd(i, W)))...)
132-
v₃ = f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...)
133-
v₄ = f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...)
134-
vnoaliasstore!(gep(ptry, i ), v₁)
135-
vnoaliasstore!(gep(ptry, vadd(i, W)), v₂)
136-
vnoaliasstore!(gep(ptry, vadd(i, 2W)), v₃)
137-
vnoaliasstore!(gep(ptry, vadd(i, 3W)), v₄)
121+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((j << Wsh,))
122+
vnoaliasstore!(ptry, f(vload.(ptrargs, index)...), index)
138123
end
139124
ii = Niter << Wsh
140125
while ii < N - (W - 1) # stops at 16 when
141-
vᵢ = f(vload.(V, gep.(ptrargs, ii))...)
142-
vnoaliasstore!(gep(ptry, ii), vᵢ)
126+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(ii),),))...), (MM{W}(ii),))
143127
ii = vadd(ii, W)
144128
end
145129
if ii < N
146130
m = mask(T, N & (W - 1))
147-
vnoaliasstore!(gep(ptry, ii), f(vload.(V, gep.(ptrargs, ii), m)...), m)
131+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(ii),),), m)...), (MM{W}(ii),), m)
148132
end
149133
y
150134
end

src/mapreduce.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ function mapreduce_simple(f::F, op::OP, args::Vararg{DenseArray{<:NativeTypes},A
1818
ptrargs = ntuple(a -> pointer(args[a]), Val(A))
1919
N = length(first(args))
2020
iszero(N) && throw("Length of vector is 0!")
21+
st = ntuple(a -> VectorizationBase.static_sizeof(eltype(args[a])), Val(A))
2122
a_0 = f(vload.(ptrargs)...); i = 1
2223
while i < N
23-
a_0 = op(a_0, f(vload.(gep.(ptrargs, i))...)); i += 1
24+
a_0 = op(a_0, f(vload.(ptrargs, VectorizationBase.lazymul.(st, i))...)); i += 1
2425
end
2526
a_0
2627
end
@@ -43,28 +44,24 @@ function vmapreduce(f::F, op::OP, arg1::DenseArray{T}, args::Vararg{DenseArray{T
4344
end
4445
end
4546
function _vmapreduce(f::F, op::OP, ::StaticInt{W}, N, ::Type{T}, args::Vararg{DenseArray{<:NativeTypes},A}) where {F,OP,A,W,T}
46-
ptrargs = pointer.(args)
47-
a_0 = f(vload.(Val{W}(), ptrargs)...); i = W
47+
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
4848
if N ≥ 4W
49-
a_1 = f(vload.(Val{W}(), gep.(ptrargs, i))...); i += W
50-
a_2 = f(vload.(Val{W}(), gep.(ptrargs, i))...); i += W
51-
a_3 = f(vload.(Val{W}(), gep.(ptrargs, i))...); i += W
49+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((Zero(),)); i = 4W
50+
au = f(vload.(ptrargs, index)...)
5251
while i < N - ((W << 2) - 1)
53-
a_0 = op(a_0, f(vload.(Val{W}(), gep.(ptrargs, i))...)); i += W
54-
a_1 = op(a_1, f(vload.(Val{W}(), gep.(ptrargs, i))...)); i += W
55-
a_2 = op(a_2, f(vload.(Val{W}(), gep.(ptrargs, i))...)); i += W
56-
a_3 = op(a_3, f(vload.(Val{W}(), gep.(ptrargs, i))...)); i += W
52+
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((i,)); i += 4W
53+
au = op(au, f(vload.(ptrargs, index)...))
5754
end
58-
a_0 = op(a_0, a_1)
59-
a_2 = op(a_2, a_3)
60-
a_0 = op(a_0, a_2)
55+
a_0 = VectorizationBase.reduce_to_onevec(op, au)
56+
else
57+
a_0 = f(vload.(ptrargs, ((MM{W}(Zero()),),))...); i = W
6158
end
6259
while i < N - (W - 1)
63-
a_0 = op(a_0, f(vload.(Val{W}(), gep.(ptrargs, i))...)); i += W
60+
a_0 = op(a_0, f(vload.(ptrargs, ((MM{W}(i),),))...)); i += W
6461
end
6562
if i < N
6663
m = mask(T, N & (W - 1))
67-
a_0 = ifelse(m, op(a_0, f(vload.(Val{W}(), gep.(ptrargs, i))...)), a_0)
64+
a_0 = ifelse(m, op(a_0, f(vload.(ptrargs, ((MM{W}(i),),))...)), a_0)
6865
end
6966
vreduce(op, a_0)
7067
end

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ Execute an `@avx` block. The block's code is represented via the arguments:
468468
- `vargs...` holds the encoded pointers of all the arrays (see `VectorizationBase`'s various pointer types).
469469
"""
470470
@generated function _avx_!(::Val{UNROLL}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB, vargs...) where {UNROLL, OPS, ARF, AM, LPSYM, LB}
471-
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
471+
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
472472
ls = _avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, vargs)
473473
# return @show avx_body(ls, UNROLL)
474474
# @show UNROLL, OPS, ARF, AM, LPSYM, LB

0 commit comments

Comments
 (0)