Skip to content

Commit 3c51cd2

Browse files
committed
add atomic_fence to vmapnt variants
1 parent fd10e63 commit 3c51cd2

File tree

2 files changed

+37
-35
lines changed

2 files changed

+37
-35
lines changed

src/simdfunctionals/map.jl

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -185,41 +185,41 @@ end
185185
# end
186186

187187
function vmap_multithread!(
188-
f::F,
189-
y::AbstractArray{T},
190-
::Val{NonTemporal},
191-
args::Vararg{AbstractArray,A}
188+
f::F,
189+
y::AbstractArray{T},
190+
::Val{NonTemporal},
191+
args::Vararg{AbstractArray,A}
192192
) where {F,T,A,NonTemporal}
193-
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
194-
ptry, ptrargs, N = setup_vmap!(f, y, Val{NonTemporal}(), args...)
195-
# nt = min(Threads.nthreads(), VectorizationBase.SYS_CPU_THREADS, N >> (Wshift + 3))
196-
nt = min(Threads.nthreads(), num_cores(), N >> (Wshift + 5))
193+
W, Wshift = VectorizationBase.pick_vector_width_shift(T)
194+
ptry, ptrargs, N = setup_vmap!(f, y, Val{NonTemporal}(), args...)
195+
# nt = min(Threads.nthreads(), VectorizationBase.SYS_CPU_THREADS, N >> (Wshift + 3))
196+
nt = min(Threads.nthreads(), num_cores(), N >> (Wshift + 5))
197197

198-
# if !((nt > 1) && iszero(ccall(:jl_in_threaded_region, Cint, ())))
199-
if nt < 2
200-
_vmap_singlethread!(f, ptry, Zero(), N, Val{NonTemporal}(), ptrargs)
201-
return
202-
end
198+
# if !((nt > 1) && iszero(ccall(:jl_in_threaded_region, Cint, ())))
199+
if nt < 2
200+
_vmap_singlethread!(f, ptry, Zero(), N, Val{NonTemporal}(), ptrargs)
201+
return
202+
end
203203

204-
cfunc = vmap_closure(f, ptry, ptrargs, Val{NonTemporal}())
205-
vmc = VmapClosure{NonTemporal}(f, ptry, ptrargs)
206-
Nveciter = (N + (W-1)) >> Wshift
207-
Nd, Nr = divrem(Nveciter, nt)
208-
NdW = Nd << Wshift
209-
NdWr = NdW + W
210-
GC.@preserve cfunc begin
211-
start = 0
212-
for tid ∈ 1:nt-1
213-
stop = start + ifelse(tid ≤ Nr, NdWr, NdW)
214-
launch_thread_vmap!(tid, cfunc, ptry, ptrargs, start, stop)
215-
start = stop
216-
end
217-
_vmap_singlethread!(f, ptry, start, N, Val{NonTemporal}(), ptrargs)
218-
for tid ∈ 1:nt-1
219-
ThreadingUtilities.wait(tid)
220-
end
204+
cfunc = vmap_closure(f, ptry, ptrargs, Val{NonTemporal}())
205+
vmc = VmapClosure{NonTemporal}(f, ptry, ptrargs)
206+
Nveciter = (N + (W-1)) >> Wshift
207+
Nd, Nr = divrem(Nveciter, nt)
208+
NdW = Nd << Wshift
209+
NdWr = NdW + W
210+
GC.@preserve cfunc begin
211+
start = 0
212+
for tid ∈ 1:nt-1
213+
stop = start + ifelse(tid ≤ Nr, NdWr, NdW)
214+
launch_thread_vmap!(tid, cfunc, ptry, ptrargs, start, stop)
215+
start = stop
216+
end
217+
_vmap_singlethread!(f, ptry, start, N, Val{NonTemporal}(), ptrargs)
218+
for tid ∈ 1:nt-1
219+
ThreadingUtilities.wait(tid)
221220
end
222-
nothing
221+
end
222+
nothing
223223
end
224224
@generated function gc_preserve_vmap!(f::F,
225225
y::AbstractArray,
@@ -356,6 +356,7 @@ function vmapnt!(
356356
) where {F,A}
357357
if check_args(y, args...) && all_dense(y, args...)
358358
gc_preserve_vmap!(f, y, Val{true}(), Val{false}(), args...)
359+
Threads.atomic_fence()
359360
else
360361
map!(f, y, args...)
361362
end
@@ -370,6 +371,7 @@ function vmapntt!(
370371
) where {F,A}
371372
if check_args(y, args...) && all_dense(y, args...)
372373
gc_preserve_vmap!(f, y, Val{true}(), Val{true}(), args...)
374+
Threads.atomic_fence()
373375
else
374376
map!(f, y, args...)
375377
end

test/map.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
@show T, @__LINE__
55
for N ∈ [ 3, 371 ]
66
a = rand(T, N); b = rand(T, N);
7-
c0 = vmapntt(foo, a, b);
8-
c3 = similar(c0) # not aligned
9-
fill!(c3, NaN); @views vmapntt!(foo, c3[2:end], a[2:end], b[2:end]);
107
c1 = map(foo, a, b);
118
c2 = vmap(foo, a, b);
129
@test c1 ≈ c2
@@ -16,7 +13,10 @@
1613
@test c1 ≈ c2
1714
fill!(c2, NaN); @views vmapnt!(foo, c2[2:end], a[2:end], b[2:end]);
1815
@test @views c1[2:end] ≈ c2[2:end]
19-
sleep(1e-3) # non-temporal stores won't be automatically synced/coherant, so need to wait!
16+
# sleep(1e-3) # non-temporal stores won't be automatically synced/coherant, so need to wait!
17+
c0 = vmapntt(foo, a, b);
18+
c3 = similar(c0); # not aligned
19+
fill!(c3, NaN); @views vmapntt!(foo, c3[2:end], a[2:end], b[2:end]);
2020
@test c0 ≈ c1
2121
@test isnan(c3[begin])
2222
@test @views c1[2:end] ≈ c3[2:end]

0 commit comments

Comments
 (0)