Skip to content

Commit 59be95a

Browse files
committed
change Base.mapreduce to match the new function signature introduced in Julia 0.7/1.0, namely moving dims and v0 into kwargs
1 parent 3f45485 commit 59be95a

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

src/mapreduce.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
# functions in base implemented with a direct loop need to be overloaded to use mapreduce
44

55

6-
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A)
7-
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A)
8-
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A))
6+
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false)
7+
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true)
8+
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0))
99

10-
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, true, A, B))
10+
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true))
1111

1212
# hack to get around of fetching the first element of the GPUArray
1313
# as a startvalue, which is a bit complicated with the current reduce implementation
1414
function startvalue(f, T)
15-
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, 1, A)")
15+
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, A; init = 1)")
1616
end
1717
startvalue(::typeof(+), T) = zero(T)
1818
startvalue(::typeof(Base.add_sum), T) = zero(T)
@@ -50,20 +50,30 @@ gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(
5050
gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
5151
gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
5252

53-
function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}) where {T, N}
53+
function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}; dims = :, init...) where {T, N}
54+
mapreduce_impl(f, op, init.data, A, dims)
55+
end
56+
57+
function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUArray{T, N}, ::Colon) where {T, N}
5458
OT = gpu_promote_type(op, T)
5559
v0 = startvalue(op, OT) # TODO do this better
56-
mapreduce(f, op, v0, A)
60+
acc_mapreduce(f, op, v0, A, ())
5761
end
58-
function acc_mapreduce end
59-
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number)
60-
acc_mapreduce(f, op, v0, A, (B, C))
62+
63+
function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUArray{T, N}, ::Colon) where {T, N}
64+
acc_mapreduce(f, op, nt.init, A, ())
6165
end
62-
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray)
63-
acc_mapreduce(f, op, v0, A, (B,))
66+
67+
function mapreduce_impl(f, op, nt, A::GPUArray{T, N}, dims) where {T, N}
68+
Base._mapreduce_dim(f, op, nt, A, dims)
6469
end
65-
function Base.mapreduce(f, op, v0, A::GPUArray)
66-
acc_mapreduce(f, op, v0, A, ())
70+
71+
function acc_mapreduce end
72+
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray, C::Number; init)
73+
acc_mapreduce(f, op, init, A, (B, C))
74+
end
75+
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray; init)
76+
acc_mapreduce(f, op, init, A, (B,))
6777
end
6878

6979
@generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N

0 commit comments

Comments
 (0)