Skip to content

Commit 9f04bd0

Browse files
committed
Support and test multithreading vmap on ranges
1 parent 5720566 commit 9f04bd0

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

src/simdfunctionals/map.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ function _vmap_singlethread!(
109109
nothing
110110
end
111111

112-
abstract type AbstractVmapClosure{NonTemporal,F,D,N,A<:Tuple{Vararg{StridedPointer,N}}} <: Function end
112+
abstract type AbstractVmapClosure{NonTemporal,F,D,N,A<:Tuple{Vararg{Any,N}}} <: Function end
113113
struct VmapClosure{NonTemporal,F,D,N,A} <: AbstractVmapClosure{NonTemporal,F,D,N,A}
114114
f::F
115-
function VmapClosure{NonTemporal}(f::F, ::D, ::A) where {NonTemporal,F,D,N,A<:Tuple{Vararg{StridedPointer,N}}}
115+
function VmapClosure{NonTemporal}(f::F, ::D, ::A) where {NonTemporal,F,D,N,A<:Tuple{Vararg{Any,N}}}
116116
new{NonTemporal,F,D,N,A}(f)
117117
end
118118
end
@@ -160,7 +160,7 @@ end
160160
end
161161
end
162162

163-
@inline function vmap_closure(f::F, ptry::D, ptrargs::A, ::Val{NonTemporal}) where {F,D<:StridedPointer,N,A<:Tuple{Vararg{StridedPointer,N}},NonTemporal}
163+
@inline function vmap_closure(f::F, ptry::D, ptrargs::A, ::Val{NonTemporal}) where {F,D<:StridedPointer,N,A<:Tuple{Vararg{Any,N}},NonTemporal}
164164
vmc = VmapClosure{NonTemporal}(f, ptry, ptrargs)
165165
@cfunction($vmc, Cvoid, (Ptr{UInt},))
166166
end
@@ -354,25 +354,25 @@ BenchmarkTools.Trial:
354354
function vmapnt!(
355355
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
356356
) where {F,A}
357-
if check_args(y, args...) && all_dense(y, args...)
358-
gc_preserve_vmap!(f, y, Val{true}(), Val{false}(), args...)
359-
else
360-
map!(f, y, args...)
361-
end
357+
if check_args(y, args...) && all_dense(y, args...)
358+
gc_preserve_vmap!(f, y, Val{true}(), Val{false}(), args...)
359+
else
360+
map!(f, y, args...)
361+
end
362362
end
363363

364364
"""
365365
vmapntt!(::Function, dest, args...)
366366
A threaded variant of [`vmapnt!`](@ref).
367367
"""
368368
function vmapntt!(
369-
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
369+
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
370370
) where {F,A}
371-
if check_args(y, args...) && all_dense(y, args...)
372-
gc_preserve_vmap!(f, y, Val{true}(), Val{true}(), args...)
373-
else
374-
map!(f, y, args...)
375-
end
371+
if check_args(y, args...) && all_dense(y, args...)
372+
gc_preserve_vmap!(f, y, Val{true}(), Val{true}(), args...)
373+
else
374+
map!(f, y, args...)
375+
end
376376
end
377377

378378
# generic fallbacks
@@ -382,9 +382,9 @@ end
382382
@inline vmapntt!(f, args...) = map!(f, args...)
383383

384384
function vmap_call(f::F, vm!::V, args::Vararg{Any,N}) where {V,F,N}
385-
T = Base._return_type(f, Base.Broadcast.eltypes(args))
386-
dest = similar(first(args), T)
387-
vm!(f, dest, args...)
385+
T = Base._return_type(f, Base.Broadcast.eltypes(args))
386+
dest = similar(first(args), T)
387+
vm!(f, dest, args...)
388388
end
389389

390390
"""

test/map.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@
2828
@test y1 y2
2929
end
3030
@test vmap(abs2, 1:100) == map(abs2, 1:100)
31-
@test vmap(abs2, 1:3:1000) == map(abs2, 1:3:1000)
32-
@test vmap(abs2, 1.0:3.0:1000.0) map(abs2, 1.0:3.0:1000.0)
31+
@test vmapt(abs2, 1:3:10000) == map(abs2, 1:3:1000)
32+
@test vmapt(abs2, 1.0:3.0:10000.0) map(abs2, 1.0:3.0:1000.0)
3333
end

0 commit comments

Comments
 (0)