@@ -28,8 +28,14 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
28
28
# resolve ambiguities
29
29
Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
30
30
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
31
+ # dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
31
32
Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
32
33
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
34
+ # dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35
+ Base. mapreduce (f, op, A:: AnyGPUArray ;
36
+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
37
+ Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} ;
38
+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
33
39
34
40
function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
35
41
# figure out the destination container type by looking at the initializer element,
@@ -40,7 +46,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
40
46
(ET === Union{} || ET === Any) &&
41
47
error (" mapreduce cannot figure the output element type, please pass an explicit init value" )
42
48
43
- init = neutral_element (op, ET)
49
+ init = AK . neutral_element (op, ET)
44
50
else
45
51
ET = typeof (init)
46
52
end
@@ -85,14 +91,14 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
85
91
end
86
92
end
87
93
88
- Base. any (A:: AnyGPUArray{Bool} ) = mapreduce (identity, | , A)
89
- Base. all (A:: AnyGPUArray{Bool} ) = mapreduce (identity, & , A)
94
+ Base. any (A:: AnyGPUArray{Bool} ) = AK . any (identity, A)
95
+ Base. all (A:: AnyGPUArray{Bool} ) = AK . all (identity, A)
90
96
91
- Base. any (f:: Function , A:: AnyGPUArray ) = mapreduce (f, | , A)
92
- Base. all (f:: Function , A:: AnyGPUArray ) = mapreduce (f, & , A)
97
+ Base. any (f:: Function , A:: AnyGPUArray ) = AK . any (f , A)
98
+ Base. all (f:: Function , A:: AnyGPUArray ) = AK . all (f , A)
93
99
94
100
Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
95
- mapreduce (pred, Base . add_sum, A; init= init , dims= dims)
101
+ AK . count (pred, A; init, dims= dims isa Colon ? nothing : dims)
96
102
97
103
# avoid calling into `initarray!`
98
104
for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
@@ -101,7 +107,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
101
107
fname! = Symbol (fname, ' !' )
102
108
@eval begin
103
109
Base.$ (fname!)(f:: Function , r:: AnyGPUArray , A:: AnyGPUArray{T} ) where T =
104
- GPUArrays. mapreducedim! (f, $ (op), r, A; init= neutral_element ($ (op), T))
110
+ GPUArrays. mapreducedim! (f, $ (op), r, A; init= AK . neutral_element ($ (op), T))
105
111
end
106
112
end
107
113
0 commit comments