@@ -185,41 +185,41 @@ end
185
185
# end
186
186
187
187
function vmap_multithread! (
188
- f:: F ,
189
- y:: AbstractArray{T} ,
190
- :: Val{NonTemporal} ,
191
- args:: Vararg{AbstractArray,A}
188
+ f:: F ,
189
+ y:: AbstractArray{T} ,
190
+ :: Val{NonTemporal} ,
191
+ args:: Vararg{AbstractArray,A}
192
192
) where {F,T,A,NonTemporal}
193
- W, Wshift = VectorizationBase. pick_vector_width_shift (T)
194
- ptry, ptrargs, N = setup_vmap! (f, y, Val {NonTemporal} (), args... )
195
- # nt = min(Threads.nthreads(), VectorizationBase.SYS_CPU_THREADS, N >> (Wshift + 3))
196
- nt = min (Threads. nthreads (), num_cores (), N >> (Wshift + 5 ))
193
+ W, Wshift = VectorizationBase. pick_vector_width_shift (T)
194
+ ptry, ptrargs, N = setup_vmap! (f, y, Val {NonTemporal} (), args... )
195
+ # nt = min(Threads.nthreads(), VectorizationBase.SYS_CPU_THREADS, N >> (Wshift + 3))
196
+ nt = min (Threads. nthreads (), num_cores (), N >> (Wshift + 5 ))
197
197
198
- # if !((nt > 1) && iszero(ccall(:jl_in_threaded_region, Cint, ())))
199
- if nt < 2
200
- _vmap_singlethread! (f, ptry, Zero (), N, Val {NonTemporal} (), ptrargs)
201
- return
202
- end
198
+ # if !((nt > 1) && iszero(ccall(:jl_in_threaded_region, Cint, ())))
199
+ if nt < 2
200
+ _vmap_singlethread! (f, ptry, Zero (), N, Val {NonTemporal} (), ptrargs)
201
+ return
202
+ end
203
203
204
- cfunc = vmap_closure (f, ptry, ptrargs, Val {NonTemporal} ())
205
- vmc = VmapClosure {NonTemporal} (f, ptry, ptrargs)
206
- Nveciter = (N + (W- 1 )) >> Wshift
207
- Nd, Nr = divrem (Nveciter, nt)
208
- NdW = Nd << Wshift
209
- NdWr = NdW + W
210
- GC. @preserve cfunc begin
211
- start = 0
212
- for tid ∈ 1 : nt- 1
213
- stop = start + ifelse (tid ≤ Nr, NdWr, NdW)
214
- launch_thread_vmap! (tid, cfunc, ptry, ptrargs, start, stop)
215
- start = stop
216
- end
217
- _vmap_singlethread! (f, ptry, start, N, Val {NonTemporal} (), ptrargs)
218
- for tid ∈ 1 : nt- 1
219
- ThreadingUtilities. wait (tid)
220
- end
204
+ cfunc = vmap_closure (f, ptry, ptrargs, Val {NonTemporal} ())
205
+ vmc = VmapClosure {NonTemporal} (f, ptry, ptrargs)
206
+ Nveciter = (N + (W- 1 )) >> Wshift
207
+ Nd, Nr = divrem (Nveciter, nt)
208
+ NdW = Nd << Wshift
209
+ NdWr = NdW + W
210
+ GC. @preserve cfunc begin
211
+ start = 0
212
+ for tid ∈ 1 : nt- 1
213
+ stop = start + ifelse (tid ≤ Nr, NdWr, NdW)
214
+ launch_thread_vmap! (tid, cfunc, ptry, ptrargs, start, stop)
215
+ start = stop
216
+ end
217
+ _vmap_singlethread! (f, ptry, start, N, Val {NonTemporal} (), ptrargs)
218
+ for tid ∈ 1 : nt- 1
219
+ ThreadingUtilities. wait (tid)
221
220
end
222
- nothing
221
+ end
222
+ nothing
223
223
end
224
224
@generated function gc_preserve_vmap! (f:: F ,
225
225
y:: AbstractArray ,
@@ -356,6 +356,7 @@ function vmapnt!(
356
356
) where {F,A}
357
357
if check_args (y, args... ) && all_dense (y, args... )
358
358
gc_preserve_vmap! (f, y, Val {true} (), Val {false} (), args... )
359
+ Threads. atomic_fence ()
359
360
else
360
361
map! (f, y, args... )
361
362
end
@@ -370,6 +371,7 @@ function vmapntt!(
370
371
) where {F,A}
371
372
if check_args (y, args... ) && all_dense (y, args... )
372
373
gc_preserve_vmap! (f, y, Val {true} (), Val {true} (), args... )
374
+ Threads. atomic_fence ()
373
375
else
374
376
map! (f, y, args... )
375
377
end
0 commit comments