@@ -46,7 +46,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
46
46
(ET === Union{} || ET === Any) &&
47
47
error (" mapreduce cannot figure the output element type, please pass an explicit init value" )
48
48
49
- init = neutral_element (op, ET)
49
+ init = AK . neutral_element (op, ET)
50
50
else
51
51
ET = typeof (init)
52
52
end
@@ -98,7 +98,7 @@ Base.any(f::Function, A::AnyGPUArray) = AK.any(f, A)
98
98
Base. all (f:: Function , A:: AnyGPUArray ) = AK. all (f, A)
99
99
100
100
Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
101
- AK. count (pred, A; init, dims= dims isa Colon ? nothing : dims)# mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
101
+ AK. count (pred, A; init, dims= dims isa Colon ? nothing : dims)
102
102
103
103
# avoid calling into `initarray!`
104
104
for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
@@ -107,7 +107,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
107
107
fname! = Symbol (fname, ' !' )
108
108
@eval begin
109
109
Base.$ (fname!)(f:: Function , r:: AnyGPUArray , A:: AnyGPUArray{T} ) where T =
110
- GPUArrays. mapreducedim! (f, $ (op), r, A; init= neutral_element ($ (op), T))
110
+ GPUArrays. mapreducedim! (f, $ (op), r, A; init= AK . neutral_element ($ (op), T))
111
111
end
112
112
end
113
113
0 commit comments