Skip to content

Commit f705e3d

Browse files
committed
vmap_multithread! was not type stable with NonTemporal as a val argument, so split it into two separate definitions. Fixes #141.
1 parent 3b33759 commit f705e3d

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

src/map.jl

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,47 @@ end
7979
function vmap_multithreaded!(
8080
f::F,
8181
y::DenseArray{T},
82-
::Val{NonTemporal},
82+
::Val{true},
8383
args::Vararg{<:DenseArray{<:NativeTypes},A}
84-
) where {F,T,A,NonTemporal}
85-
if NonTemporal
86-
ptry, ptrargs, N = alignstores!(f, y, args...)
87-
else
88-
N = length(y)
89-
ptry = pointer(y)
90-
ptrargs = pointer.(args)
84+
) where {F,T,A}
85+
ptry, ptrargs, N = alignstores!(f, y, args...)
86+
N > 0 || return y
87+
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
88+
V = VectorizationBase.pick_vector_width_val(T)
89+
Wsh = Wshift + 2
90+
Niter = N >>> Wsh
91+
Base.Threads.@threads for j 0:Niter-1
92+
i = j << Wsh
93+
v₁ = extract_data(f(vload.(V, gep.(ptrargs, i ))...))
94+
v₂ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, W)))...))
95+
v₃ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...))
96+
v₄ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...))
97+
vstorent!(gep(ptry, i ), v₁)
98+
vstorent!(gep(ptry, vadd(i, W)), v₂)
99+
vstorent!(gep(ptry, vadd(i, 2W)), v₃)
100+
vstorent!(gep(ptry, vadd(i, 3W)), v₄)
101+
end
102+
ii = Niter << Wsh
103+
while ii < N - (W - 1) # stops at 16 when
104+
vᵢ = extract_data(f(vload.(V, gep.(ptrargs, ii))...))
105+
vstorent!(gep(ptry, ii), vᵢ)
106+
ii = vadd(ii, W)
107+
end
108+
if ii < N
109+
m = mask(T, N & (W - 1))
110+
vnoaliasstore!(gep(ptry, ii), extract_data(f(vload.(V, gep.(ptrargs, ii), m)...)), m)
91111
end
112+
y
113+
end
114+
function vmap_multithreaded!(
115+
f::F,
116+
y::DenseArray{T},
117+
::Val{false},
118+
args::Vararg{<:DenseArray{<:NativeTypes},A}
119+
) where {F,T,A}
120+
N = length(y)
121+
ptry = pointer(y)
122+
ptrargs = pointer.(args)
92123
N > 0 || return y
93124
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
94125
V = VectorizationBase.pick_vector_width_val(T)
@@ -100,26 +131,15 @@ function vmap_multithreaded!(
100131
v₂ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, W)))...))
101132
v₃ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 2W)))...))
102133
v₄ = extract_data(f(vload.(V, gep.(ptrargs, vadd(i, 3W)))...))
103-
if NonTemporal
104-
vstorent!(gep(ptry, i ), v₁)
105-
vstorent!(gep(ptry, vadd(i, W)), v₂)
106-
vstorent!(gep(ptry, vadd(i, 2W)), v₃)
107-
vstorent!(gep(ptry, vadd(i, 3W)), v₄)
108-
else
109-
vnoaliasstore!(gep(ptry, i ), v₁)
110-
vnoaliasstore!(gep(ptry, vadd(i, W)), v₂)
111-
vnoaliasstore!(gep(ptry, vadd(i, 2W)), v₃)
112-
vnoaliasstore!(gep(ptry, vadd(i, 3W)), v₄)
113-
end
134+
vnoaliasstore!(gep(ptry, i ), v₁)
135+
vnoaliasstore!(gep(ptry, vadd(i, W)), v₂)
136+
vnoaliasstore!(gep(ptry, vadd(i, 2W)), v₃)
137+
vnoaliasstore!(gep(ptry, vadd(i, 3W)), v₄)
114138
end
115139
ii = Niter << Wsh
116140
while ii < N - (W - 1) # stops at 16 when
117141
vᵢ = extract_data(f(vload.(V, gep.(ptrargs, ii))...))
118-
if NonTemporal
119-
vstorent!(gep(ptry, ii), vᵢ)
120-
else
121-
vnoaliasstore!(gep(ptry, ii), vᵢ)
122-
end
142+
vnoaliasstore!(gep(ptry, ii), vᵢ)
123143
ii = vadd(ii, W)
124144
end
125145
if ii < N

0 commit comments

Comments
 (0)