Skip to content

Commit ecde50b

Browse files
committed
Allow v(map)reduce to be more flexible in arguments it accepts
1 parent a3344fd commit ecde50b

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/mapreduce.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ vec_vreduce(op, v::Vec{1}) = VectorizationBase.extractelement(v, 0)
1414
a
1515
end
1616

17-
function mapreduce_simple(f::F, op::OP, args::Vararg{DenseNativeArray,A}) where {F,OP,A}
17+
function mapreduce_simple(f::F, op::OP, args::Vararg{AbstractArray,A}) where {F,OP,A}
1818
ptrargs = ntuple(a -> pointer(args[a]), Val(A))
1919
N = length(first(args))
2020
iszero(N) && throw("Length of vector is 0!")
@@ -32,7 +32,10 @@ end
3232
3333
Vectorized version of `mapreduce`. Applies `f` to each element of the arrays `A`, and reduces the result with `op`.
3434
"""
35-
@inline function vmapreduce(f::F, op::OP, arg1::DenseArray{T}, args::Vararg{DenseArray{T},A}) where {F,OP,T<:NativeTypes,A}
35+
@inline function vmapreduce(f::F, op::OP, arg1::AbstractArray{T}, args::Vararg{AbstractArray{T},A}) where {F,OP,T<:NativeTypes,A}
36+
if !(check_args(args1, args...) && all_dense(arg1, args...))
37+
return mapreduce(f, op, arg1, args...)
38+
end
3639
N = length(arg1)
3740
iszero(A) || @assert all(length.(args) .== N)
3841
W = VectorizationBase.pick_vector_width(T)
@@ -43,7 +46,7 @@ Vectorized version of `mapreduce`. Applies `f` to each element of the arrays `A`
4346
_vmapreduce(f, op, V, N, T, arg1, args...)
4447
end
4548
end
46-
@inline function _vmapreduce(f::F, op::OP, ::StaticInt{W}, N, ::Type{T}, args::Vararg{DenseArray{<:NativeTypes},A}) where {F,OP,A,W,T}
49+
@inline function _vmapreduce(f::F, op::OP, ::StaticInt{W}, N, ::Type{T}, args::Vararg{AbstractArray{<:NativeTypes},A}) where {F,OP,A,W,T}
4750
ptrargs = VectorizationBase.zero_offsets.(stridedpointer.(args))
4851
if N 4W
4952
index = VectorizationBase.Unroll{1,1,4,1,W,0x0000000000000000}((Zero(),)); i = 4W
@@ -79,6 +82,9 @@ Vectorized version of `reduce`. Reduces the array `A` using the operator `op`.
7982

8083
for (op, init) in zip((:+, :max, :min), (:zero, :typemin, :typemax))
8184
@eval @inline function vreduce(::typeof($op), arg; dims = nothing)
85+
if !(check_args(arg) && all_dense(arg))
86+
return reduce($op, arg, dims = dims)
87+
end
8288
isnothing(dims) && return _vreduce($op, arg)
8389
isone(ndims(arg)) && return [_vreduce($op, arg)]
8490
@assert length(dims) == 1

0 commit comments

Comments
 (0)