Skip to content

Commit 561bcbb

Browse files
committed
Allow vmap to apply to more array types
1 parent c204d59 commit 561bcbb

File tree

1 file changed

+60
-42
lines changed

1 file changed

+60
-42
lines changed

src/map.jl

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11

2-
const DenseNativeArray = DenseArray{<:NativeTypes}
3-
42
"""
53
`vstorent!` (non-temporal store) requires data to be aligned.
64
`alignstores!` will align `y` in preparation for the non-temporal maps.
75
"""
86
function alignstores!(
9-
f::F, y::DenseArray{T},
10-
args::Vararg{DenseNativeArray,A}
7+
f::F, y::AbstractArray{T},
8+
args::Vararg{AbstractArray,A}
119
) where {F, T <: Base.HWReal, A}
1210
N = length(y)
1311
ptry = VectorizationBase.zstridedpointer(y)
@@ -32,9 +30,9 @@ function alignstores!(
3230
end
3331

3432
function vmap_singlethread!(
35-
f::F, y::DenseArray{T},
33+
f::F, y::AbstractArray{T},
3634
::Val{NonTemporal},
37-
args::Vararg{DenseNativeArray,A}
35+
args::Vararg{AbstractArray,A}
3836
) where {F,T <: Base.HWReal, A, NonTemporal}
3937
if NonTemporal # if stores into `y` aren't aligned, we'll get a crash
4038
ptry, ptrargs, N = alignstores!(f, y, args...)
@@ -60,36 +58,36 @@ function vmap_singlethread!(
6058
end
6159
i = vadd_fast(i, StaticInt{UNROLL}() * W)
6260
end
63-
if Base.libllvm_version v"11"
64-
Nm1 = vsub_fast(N, 1)
65-
while i < N # stops at 16 when
66-
m = mask(V, i, Nm1)
67-
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
68-
i = vadd_fast(i, W)
69-
end
70-
else
71-
while i < N - (W - 1) # stops at 16 when
72-
vᵣ = f(vload.(ptrargs, ((MM{W}(i),),))...)
73-
if NonTemporal
74-
vstorent!(ptry, vᵣ, (MM{W}(i),))
75-
else
76-
vnoaliasstore!(ptry, vᵣ, (MM{W}(i),))
77-
end
78-
i = vadd_fast(i, W)
79-
end
80-
if i < N
81-
m = mask(T, N & (W - 1))
82-
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
61+
# if Base.libllvm_version ≥ v"11" # this seems to be slower
62+
# Nm1 = vsub_fast(N, 1)
63+
# while i < N # stops at 16 when
64+
# m = mask(V, i, Nm1)
65+
# vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
66+
# i = vadd_fast(i, W)
67+
# end
68+
# else
69+
while i < N - (W - 1) # stops at 16 when
70+
vᵣ = f(vload.(ptrargs, ((MM{W}(i),),))...)
71+
if NonTemporal
72+
vstorent!(ptry, vᵣ, (MM{W}(i),))
73+
else
74+
vnoaliasstore!(ptry, vᵣ, (MM{W}(i),))
8375
end
76+
i = vadd_fast(i, W)
77+
end
78+
if i < N
79+
m = mask(T, N & (W - 1))
80+
vnoaliasstore!(ptry, f(vload.(ptrargs, ((MM{W}(i),),), m)...), (MM{W}(i,),), m)
8481
end
82+
# end
8583
y
8684
end
8785

8886
function vmap_multithreaded!(
8987
f::F,
90-
y::DenseArray{T},
88+
y::AbstractArray{T},
9189
::Val{true},
92-
args::Vararg{DenseNativeArray,A}
90+
args::Vararg{AbstractArray,A}
9391
) where {F,T,A}
9492
ptry, ptrargs, N = alignstores!(f, y, args...)
9593
N > 0 || return y
@@ -114,9 +112,9 @@ function vmap_multithreaded!(
114112
end
115113
function vmap_multithreaded!(
116114
f::F,
117-
y::DenseArray{T},
115+
y::AbstractArray{T},
118116
::Val{false},
119-
args::Vararg{DenseNativeArray,A}
117+
args::Vararg{AbstractArray,A}
120118
) where {F,T,A}
121119
N = length(y)
122120
ptry = VectorizationBase.zstridedpointer(y)
@@ -142,6 +140,10 @@ function vmap_multithreaded!(
142140
y
143141
end
144142

143+
Base.@pure _all_dense(::ArrayInterface.DenseDims{D}) where {D} = all(D)
144+
@inline all_dense() = true
145+
@inline all_dense(A::AbstractArray) = _all_dense(ArrayInterface.dense_dims(A))
146+
@inline all_dense(A::AbstractArray, B::AbstractArray, C::Vararg{AbstractArray,K}) where {K} = all_dense(A) && all_dense(B, C...)
145147

146148
"""
147149
vmap!(f, destination, a::AbstractArray)
@@ -151,9 +153,13 @@ Vectorized-`map!`, applying `f` to each element of `a` (or paired elements of `a
151153
and storing the result in `destination`.
152154
"""
153155
function vmap!(
154-
f::F, y::DenseArray{T}, args::Vararg{DenseNativeArray,A}
155-
) where {F,T<:Base.HWReal,A}
156-
vmap_singlethread!(f, y, Val{false}(), args...)
156+
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
157+
) where {F,A}
158+
if check_args(y, args...) && all_dense(y, args...)
159+
vmap_singlethread!(f, y, Val{false}(), args...)
160+
else
161+
map!(f, y, args...)
162+
end
157163
end
158164

159165

@@ -163,9 +169,13 @@ end
163169
Like `vmap!` (see `vmap!`), but uses `Threads.@threads` for parallel execution.
164170
"""
165171
function vmapt!(
166-
f::F, y::DenseArray{T}, args::Vararg{DenseNativeArray,A}
167-
) where {F,T<:Base.HWReal,A}
168-
vmap_multithreaded!(f, y, Val{false}(), args...)
172+
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
173+
) where {F,A}
174+
if check_args(y, args...) && all_dense(y, args...)
175+
vmap_multithreaded!(f, y, Val{false}(), args...)
176+
else
177+
map!(f, y, args...)
178+
end
169179
end
170180

171181

@@ -225,9 +235,13 @@ BenchmarkTools.Trial:
225235
```
226236
"""
227237
function vmapnt!(
228-
f::F, y::DenseArray{T}, args::Vararg{DenseNativeArray,A}
229-
) where {F,T<:Base.HWReal,A}
230-
vmap_singlethread!(f, y, Val{true}(), args...)
238+
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
239+
) where {F,A}
240+
if check_args(y, args...) && all_dense(y, args...)
241+
vmap_singlethread!(f, y, Val{true}(), args...)
242+
else
243+
map!(f, y, args...)
244+
end
231245
end
232246

233247
"""
@@ -236,9 +250,13 @@ end
236250
Like `vmapnt!` (see `vmapnt!`), but uses `Threads.@threads` for parallel execution.
237251
"""
238252
function vmapntt!(
239-
f::F, y::DenseArray{T}, args::Vararg{DenseNativeArray,A}
240-
) where {F,T<:Base.HWReal,A}
241-
vmap_multithreaded!(f, y, Val{true}(), args...)
253+
f::F, y::AbstractArray, args::Vararg{AbstractArray,A}
254+
) where {F,A}
255+
if check_args(y, args...) && all_dense(y, args...)
256+
vmap_multithreaded!(f, y, Val{true}(), args...)
257+
else
258+
map!(f, y, args...)
259+
end
242260
end
243261

244262
# generic fallbacks

0 commit comments

Comments
 (0)