Skip to content

Commit 7e6dc6b

Browse files
committed
Refactored map.jl, added threaded temporal map.
1 parent a21810f commit 7e6dc6b

File tree

3 files changed

+133
-75
lines changed

3 files changed

+133
-75
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast
2828

2929
export LowDimArray, stridedpointer,
3030
@avx, @_avx, *ˡ, _avx_!,
31-
vmap, vmap!, vmapnt, vmapnt!, vmapntt, vmapntt!,
31+
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
3232
vfilter, vfilter!, vmapreduce, vreduce
3333

3434
const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloopeltype##")

src/map.jl

Lines changed: 130 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,8 @@
1-
# Expression-generator for vmap!
2-
function vmap_quote(N, ::Type{T}) where {T}
3-
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
4-
val = Expr(:call, Expr(:curly, :Val, W))
5-
q = Expr(:block, Expr(:(=), :M, Expr(:call, :length, :dest)), Expr(:(=), :vdest, Expr(:call, :pointer, :dest)), Expr(:(=), :m, 0))
6-
fcall = Expr(:call, :f)
7-
loopbody = Expr(:block, Expr(:call, :vstore!, Expr(:call, :gep, :vdest, :m), fcall), Expr(:(+=), :m, W))
8-
fcallmask = Expr(:call, :f)
9-
bodymask = Expr(:block, Expr(:(=), :__mask__, Expr(:call, :mask, val, Expr(:call, :&, :M, W-1))), Expr(:call, :vstore!, Expr(:call, :gep, :vdest, :m), fcallmask, :__mask__))
10-
for n ∈ 1:N
11-
arg_n = Symbol(:varg_,n)
12-
push!(q.args, Expr(:(=), arg_n, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:call, :pointer, Expr(:ref, :args, n)))))
13-
push!(fcall.args, Expr(:call, :vload, val, Expr(:call, :gep, arg_n, :m)))
14-
push!(fcallmask.args, Expr(:call, :vload, val, Expr(:call, :gep, arg_n, :m), :__mask__))
15-
end
16-
loop = Expr(:for, Expr(:(=), :_, Expr(:call, :(:), 0, Expr(:call, :-, Expr(:call, :(>>>), :M, Wshift), 1))), loopbody)
17-
push!(q.args, loop)
18-
ifmask = Expr(:if, Expr(:call, :(!=), :m, :M), bodymask)
19-
push!(q.args, ifmask)
20-
push!(q.args, :dest)
21-
q
22-
end
23-
"""
24-
vmap!(f, destination, a::AbstractArray)
25-
vmap!(f, destination, a::AbstractArray, b::AbstractArray, ...)
261

27-
Vectorized-`map!`, applying `f` to each element of `a` (or paired elements of `a`, `b`, ...)
28-
and storing the result in `destination`.
292
"""
30-
@generated function vmap!(f::F, dest::AbstractArray{T}, args::Vararg{<:AbstractArray,N}) where {F,T,N}
31-
# do not change argnames here without compensatory changes in vmap_quote
32-
vmap_quote(N, T)
33-
end
34-
3+
`vstorent!` (non-temporal store) requires data to be aligned.
4+
`alignstores!` will align `y` in preparation for the non-temporal maps.
5+
"""
356
function alignstores!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
367
N = length(y)
378
ptry = pointer(y)
@@ -46,13 +17,129 @@ function alignstores!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {
4617
if N < i
4718
m &= mask(T, N & (W - 1))
4819
end
49-
vstore!(ptry, extract_data(f(vload.(V, ptrargs, m)...)), m)
20+
vnoaliasstore!(ptry, extract_data(f(vload.(V, ptrargs, m)...)), m)
5021
gep(ptry, i), gep.(ptrargs, i), N - i
5122
else
5223
ptry, ptrargs, N
5324
end
5425
end
5526

27+
function vmap_singlethread!(f::F, y::AbstractVector{T}, ::Val{NonTemporal}, args::Vararg{<:Any,A}) where {F,T,A,NonTemporal}
28+
if NonTemporal
29+
ptry, ptrargs, N = alignstores!(f, y, args...)
30+
else
31+
N = length(y)
32+
ptry = pointer(y)
33+
ptrargs = pointer.(args)
34+
end
35+
i = 0
36+
W = VectorizationBase.pick_vector_width(T)
37+
V = VectorizationBase.pick_vector_width_val(T)
38+
while i < N - ((W << 2) - 1)
39+
v₁ = extract_data(f(vload.(V, gep.(ptrargs, i ))...))
40+
v₂ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, W)))...))
41+
v₃ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...))
42+
v₄ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...))
43+
if NonTemporal
44+
vstorent!(gep(ptry, i ), v₁)
45+
vstorent!(gep(ptry, vadd(i, W)), v₂)
46+
vstorent!(gep(ptry, vadd(i, 2W)), v₃)
47+
vstorent!(gep(ptry, vadd(i, 3W)), v₄)
48+
else
49+
vnoaliasstore!(gep(ptry, i ), v₁)
50+
vnoaliasstore!(gep(ptry, vadd(i, W)), v₂)
51+
vnoaliasstore!(gep(ptry, vadd(i, 2W)), v₃)
52+
vnoaliasstore!(gep(ptry, vadd(i, 3W)), v₄)
53+
end
54+
i = vadd(i, 4W)
55+
end
56+
while i < N - (W - 1) # stops at 16 when
57+
vᵢ = extract_data(f(vload.(V, gep.(ptrargs, i))...))
58+
if NonTemporal
59+
vstorent!(gep(ptry, i), vᵢ)
60+
else
61+
vnoaliasstore!(gep(ptry, i), vᵢ)
62+
end
63+
i = vadd(i, W)
64+
end
65+
if i < N
66+
m = mask(T, N & (W - 1))
67+
vnoaliasstore!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i), m)...)), m)
68+
end
69+
y
70+
end
71+
72+
function vmap_multithreaded!(f::F, y::AbstractVector{T}, ::Val{NonTemporal}, args::Vararg{<:Any,A}) where {F,T,A,NonTemporal}
73+
if NonTemporal
74+
ptry, ptrargs, N = alignstores!(f, y, args...)
75+
else
76+
N = length(y)
77+
ptry = pointer(y)
78+
ptrargs = pointer.(args)
79+
end
80+
N > 0 || return y
81+
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
82+
V = VectorizationBase.pick_vector_width_val(T)
83+
Wsh = Wshift + 2
84+
Niter = N >>> Wsh
85+
Base.Threads.@threads for j ∈ 0:Niter-1
86+
i = j << Wsh
87+
v₁ = extract_data(f(vload.(V, gep.(ptrargs, i ))...))
88+
v₂ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, W)))...))
89+
v₃ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...))
90+
v₄ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...))
91+
if NonTemporal
92+
vstorent!(gep(ptry, i ), v₁)
93+
vstorent!(gep(ptry, vadd(i, W)), v₂)
94+
vstorent!(gep(ptry, vadd(i, 2W)), v₃)
95+
vstorent!(gep(ptry, vadd(i, 3W)), v₄)
96+
else
97+
vnoaliasstore!(gep(ptry, i ), v₁)
98+
vnoaliasstore!(gep(ptry, vadd(i, W)), v₂)
99+
vnoaliasstore!(gep(ptry, vadd(i, 2W)), v₃)
100+
vnoaliasstore!(gep(ptry, vadd(i, 3W)), v₄)
101+
end
102+
end
103+
ii = Niter << Wsh
104+
while ii < N - (W - 1) # stops at 16 when
105+
vᵢ = extract_data(f(vload.(V, gep.(ptrargs, ii))...))
106+
if NonTemporal
107+
vstorent!(gep(ptry, ii), vᵢ)
108+
else
109+
vnoaliasstore!(gep(ptry, ii), vᵢ)
110+
end
111+
ii = vadd(ii, W)
112+
end
113+
if ii < N
114+
m = mask(T, N & (W - 1))
115+
vnoaliasstore!(gep(ptry, ii), extract_data(f(vload.(V, gep.(ptrargs, ii), m)...)), m)
116+
end
117+
y
118+
end
119+
120+
121+
"""
122+
vmap!(f, destination, a::AbstractArray)
123+
vmap!(f, destination, a::AbstractArray, b::AbstractArray, ...)
124+
125+
Vectorized-`map!`, applying `f` to each element of `a` (or paired elements of `a`, `b`, ...)
126+
and storing the result in `destination`.
127+
"""
128+
function vmap!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
129+
vmap_singlethread!(f, y, Val{false}(), args...)
130+
end
131+
132+
133+
"""
134+
vmapt!(::Function, dest, args...)
135+
136+
Like `vmap!` (see `vmap!`), but uses `Threads.@threads` for parallel execution.
137+
"""
138+
function vmapt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
139+
vmap_multithreaded!(f, y, Val{false}(), args...)
140+
end
141+
142+
56143
"""
57144
vmapnt!(::Function, dest, args...)
58145
@@ -109,24 +196,7 @@ BenchmarkTools.Trial:
109196
```
110197
"""
111198
function vmapnt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
112-
ptry, ptrargs, N = alignstores!(f, y, args...)
113-
i = 0
114-
W = VectorizationBase.pick_vector_width(T)
115-
V = VectorizationBase.pick_vector_width_val(T)
116-
while i < N - ((W << 2) - 1)
117-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
118-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
119-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
120-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
121-
end
122-
while i < N - (W - 1) # stops at 16 when
123-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
124-
end
125-
if i < N
126-
m = mask(T, N & (W - 1))
127-
vstore!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i), m)...)), m)
128-
end
129-
y
199+
vmap_singlethread!(f, y, Val{true}(), args...)
130200
end
131201

132202
"""
@@ -135,28 +205,7 @@ end
135205
Like `vmapnt!` (see `vmapnt!`), but uses `Threads.@threads` for parallel execution.
136206
"""
137207
function vmapntt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
138-
ptry, ptrargs, N = alignstores!(f, y, args...)
139-
N > 0 || return y
140-
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
141-
V = VectorizationBase.pick_vector_width_val(T)
142-
Wsh = Wshift + 2
143-
Niter = N >>> Wsh
144-
Base.Threads.@threads for j ∈ 0:Niter-1
145-
i = j << Wsh
146-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
147-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
148-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...))); i += W
149-
vstorent!(gep(ptry, i), extract_data(f(vload.(V, gep.(ptrargs, i))...)))
150-
end
151-
ii = Niter << Wsh
152-
while ii < N - (W - 1) # stops at 16 when
153-
vstorent!(gep(ptry, ii), extract_data(f(vload.(V, gep.(ptrargs, ii))...))); ii += W
154-
end
155-
if ii < N
156-
m = mask(T, N & (W - 1))
157-
vstore!(gep(ptry, ii), extract_data(f(vload.(V, gep.(ptrargs, ii), m)...)), m)
158-
end
159-
y
208+
vmap_multithreaded!(f, y, Val{true}(), args...)
160209
end
161210

162211
function vmap_call(f::F, vm!::V, args::Vararg{<:Any,N}) where {V,F,N}
@@ -174,6 +223,14 @@ and returning a new array.
174223
"""
175224
vmap(f::F, args::Vararg{<:Any,N}) where {F,N} = vmap_call(f, vmap!, args...)
176225

226+
"""
227+
vmapt(f, a::AbstractArray)
228+
vmapt(f, a::AbstractArray, b::AbstractArray, ...)
229+
230+
A threaded variant of [`vmap`](@ref).
231+
"""
232+
vmapt(f::F, args::Vararg{<:Any,N}) where {F,N} = vmap_call(f, vmapt!, args...)
233+
177234
"""
178235
vmapnt(f, a::AbstractArray)
179236
vmapnt(f, a::AbstractArray, b::AbstractArray, ...)

test/map.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
c1 = map(foo, a, b);
88
c2 = vmap(foo, a, b);
99
@test c1 ≈ c2
10+
c2 = vmapt(foo, a, b);
11+
@test c1 ≈ c2
1012
c2 = vmapnt(foo, a, b);
1113
@test c1 ≈ c2
1214
c2 = vmapntt(foo, a, b);
@@ -21,6 +23,5 @@
2123
map!(xᵢ -> clenshaw(xᵢ, c), y1, x)
2224
vmap!(xᵢ -> clenshaw(xᵢ, c), y2, x)
2325
@test y1 ≈ y2
24-
2526
end
2627
end

0 commit comments

Comments
 (0)