@@ -4,11 +4,11 @@ const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}
4
4
5
5
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
6
6
# argument `init` value to avoid eager initialization of `R` (if set to something).
7
- mapreducedim! (f, op, R:: AbstractGPUArray , A:: AbstractArrayOrBroadcasted ;
7
+ mapreducedim! (f, op, R:: AnyGPUArray , A:: AbstractArrayOrBroadcasted ;
8
8
init= nothing ) = error (" Not implemented" ) # COV_EXCL_LINE
9
9
# resolve ambiguities
10
- Base. mapreducedim! (f, op, R:: AbstractGPUArray , A:: AbstractArray ) = mapreducedim! (f, op, R, A)
11
- Base. mapreducedim! (f, op, R:: AbstractGPUArray , A:: Broadcast.Broadcasted ) = mapreducedim! (f, op, R, A)
10
+ Base. mapreducedim! (f, op, R:: AnyGPUArray , A:: AbstractArray ) = mapreducedim! (f, op, R, A)
11
+ Base. mapreducedim! (f, op, R:: AnyGPUArray , A:: Broadcast.Broadcasted ) = mapreducedim! (f, op, R, A)
12
12
13
13
neutral_element (op, T) =
14
14
error (""" GPUArrays.jl needs to know the neutral element for your operator `$op `.
@@ -24,7 +24,7 @@ neutral_element(::typeof(Base.min), T) = typemax(T)
24
24
neutral_element (:: typeof (Base. max), T) = typemin (T)
25
25
26
26
# resolve ambiguities
27
- Base. mapreduce (f, op, A:: AbstractGPUArray , As:: AbstractArrayOrBroadcasted... ;
27
+ Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
28
28
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
29
29
Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
30
30
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
@@ -68,24 +68,24 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
68
68
end
69
69
end
70
70
71
- Base. any (A:: AbstractGPUArray {Bool} ) = mapreduce (identity, | , A)
72
- Base. all (A:: AbstractGPUArray {Bool} ) = mapreduce (identity, & , A)
71
+ Base. any (A:: AnyGPUArray {Bool} ) = mapreduce (identity, | , A)
72
+ Base. all (A:: AnyGPUArray {Bool} ) = mapreduce (identity, & , A)
73
73
74
- Base. any (f:: Function , A:: AbstractGPUArray ) = mapreduce (f, | , A)
75
- Base. all (f:: Function , A:: AbstractGPUArray ) = mapreduce (f, & , A)
74
+ Base. any (f:: Function , A:: AnyGPUArray ) = mapreduce (f, | , A)
75
+ Base. all (f:: Function , A:: AnyGPUArray ) = mapreduce (f, & , A)
76
76
77
- Base. count (pred:: Function , A:: AbstractGPUArray ; dims= :) =
77
+ Base. count (pred:: Function , A:: AnyGPUArray ; dims= :) =
78
78
mapreduce (pred, Base. add_sum, A; init= 0 , dims= dims)
79
79
80
- Base.:(== )(A:: AbstractGPUArray , B:: AbstractGPUArray ) = Bool (mapreduce (== , & , A, B))
80
+ Base.:(== )(A:: AnyGPUArray , B:: AnyGPUArray ) = Bool (mapreduce (== , & , A, B))
81
81
82
82
# avoid calling into `initarray!`
83
83
for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
84
84
(:maximum , :(Base. max)), (:minimum , :(Base. min)),
85
85
(:all , :& ), (:any , :| )]
86
86
fname! = Symbol (fname, ' !' )
87
87
@eval begin
88
- Base.$ (fname!)(f:: Function , r:: AbstractGPUArray , A:: AbstractGPUArray {T} ) where T =
88
+ Base.$ (fname!)(f:: Function , r:: AnyGPUArray , A:: AnyGPUArray {T} ) where T =
89
89
GPUArrays. mapreducedim! (f, $ (op), r, A; init= neutral_element ($ (op), T))
90
90
end
91
91
end
0 commit comments