Skip to content

Commit 816c105

Browse files
committed
Widen mapreduce signatures to support wrappers.
1 parent caef400 commit 816c105

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/host/mapreduce.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}
44

55
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
66
# argument `init` value to avoid eager initialization of `R` (if set to something).
7-
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArrayOrBroadcasted;
7+
mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArrayOrBroadcasted;
88
init=nothing) = error("Not implemented") # COV_EXCL_LINE
99
# resolve ambiguities
10-
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
11-
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)
10+
Base.mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
11+
Base.mapreducedim!(f, op, R::AnyGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)
1212

1313
neutral_element(op, T) =
1414
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
@@ -24,7 +24,7 @@ neutral_element(::typeof(Base.min), T) = typemax(T)
2424
neutral_element(::typeof(Base.max), T) = typemin(T)
2525

2626
# resolve ambiguities
27-
Base.mapreduce(f, op, A::AbstractGPUArray, As::AbstractArrayOrBroadcasted...;
27+
Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...;
2828
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
2929
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
3030
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
@@ -68,24 +68,24 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
6868
end
6969
end
7070

71-
Base.any(A::AbstractGPUArray{Bool}) = mapreduce(identity, |, A)
72-
Base.all(A::AbstractGPUArray{Bool}) = mapreduce(identity, &, A)
71+
Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A)
72+
Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A)
7373

74-
Base.any(f::Function, A::AbstractGPUArray) = mapreduce(f, |, A)
75-
Base.all(f::Function, A::AbstractGPUArray) = mapreduce(f, &, A)
74+
Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A)
75+
Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
7676

77-
Base.count(pred::Function, A::AbstractGPUArray; dims=:) =
77+
Base.count(pred::Function, A::AnyGPUArray; dims=:) =
7878
mapreduce(pred, Base.add_sum, A; init=0, dims=dims)
7979

80-
Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B))
80+
Base.:(==)(A::AnyGPUArray, B::AnyGPUArray) = Bool(mapreduce(==, &, A, B))
8181

8282
# avoid calling into `initarray!`
8383
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
8484
(:maximum, :(Base.max)), (:minimum, :(Base.min)),
8585
(:all, :&), (:any, :|)]
8686
fname! = Symbol(fname, '!')
8787
@eval begin
88-
Base.$(fname!)(f::Function, r::AbstractGPUArray, A::AbstractGPUArray{T}) where T =
88+
Base.$(fname!)(f::Function, r::AnyGPUArray, A::AnyGPUArray{T}) where T =
8989
GPUArrays.mapreducedim!(f, $(op), r, A; init=neutral_element($(op), T))
9090
end
9191
end

0 commit comments

Comments
 (0)