Skip to content

Commit a310fc7

Browse files
committed
AK mapreduce
TODO: - Fix extrema - Dims arrays - Multi-input
1 parent 7afba33 commit a310fc7

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/host/mapreduce.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,14 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
2828
# resolve ambiguities
2929
Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...;
3030
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
31+
# dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
3132
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
3233
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
34+
# dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35+
Base.mapreduce(f, op, A::AnyGPUArray;
36+
dims=:, init=nothing) = AK.mapreduce(f, op, A; init, dims=dims isa Colon ? nothing : dims)
37+
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle};
38+
dims=:, init=nothing) = AK.mapreduce(f, op, A; init, dims=dims isa Colon ? nothing : dims)
3339

3440
function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,N,D}
3541
# figure out the destination container type by looking at the initializer element,
@@ -85,14 +91,14 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
8591
end
8692
end
8793

88-
Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A)
89-
Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A)
94+
Base.any(A::AnyGPUArray{Bool}) = AK.any(identity, A)
95+
Base.all(A::AnyGPUArray{Bool}) = AK.all(identity, A)
9096

91-
Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A)
92-
Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
97+
Base.any(f::Function, A::AnyGPUArray) = AK.any(f, A)
98+
Base.all(f::Function, A::AnyGPUArray) = AK.all(f, A)
9399

94100
Base.count(pred::Function, A::AnyGPUArray; dims=:, init=0) =
95-
mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
101+
AK.count(pred, A; init, dims=dims isa Colon ? nothing : dims)# mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
96102

97103
# avoid calling into `initarray!`
98104
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),

0 commit comments

Comments
 (0)