@@ -28,8 +28,14 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
28
28
# resolve ambiguities
29
29
Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
30
30
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
31
+ # dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
31
32
Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
32
33
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
34
+ # dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35
+ Base. mapreduce (f, op, A:: AnyGPUArray ;
36
+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
37
+ Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} ;
38
+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
33
39
34
40
function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
35
41
# figure out the destination container type by looking at the initializer element,
@@ -85,14 +91,14 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
85
91
end
86
92
end
87
93
88
- Base. any (A:: AnyGPUArray{Bool} ) = mapreduce (identity, | , A)
89
- Base. all (A:: AnyGPUArray{Bool} ) = mapreduce (identity, & , A)
94
+ Base. any (A:: AnyGPUArray{Bool} ) = AK . any (identity, A)
95
+ Base. all (A:: AnyGPUArray{Bool} ) = AK . all (identity, A)
90
96
91
- Base. any (f:: Function , A:: AnyGPUArray ) = mapreduce (f, | , A)
92
- Base. all (f:: Function , A:: AnyGPUArray ) = mapreduce (f, & , A)
97
+ Base. any (f:: Function , A:: AnyGPUArray ) = AK . any (f , A)
98
+ Base. all (f:: Function , A:: AnyGPUArray ) = AK . all (f , A)
93
99
94
100
Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
95
- mapreduce (pred, Base. add_sum, A; init= init, dims= dims)
101
+ AK . count (pred, A; init, dims = dims isa Colon ? nothing : dims) # mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
96
102
97
103
# avoid calling into `initarray!`
98
104
for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
0 commit comments