|
3 | 3 | # functions in base implemented with a direct loop need to be overloaded to use mapreduce
|
4 | 4 |
|
5 | 5 |
|
6 |
| -Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A) |
7 |
| -Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A) |
8 |
| -Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A)) |
| 6 | +Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false) |
| 7 | +Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true) |
| 8 | +Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0)) |
9 | 9 |
|
10 |
| -Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, true, A, B)) |
| 10 | +Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true)) |
11 | 11 |
|
12 | 12 | # hack to get around of fetching the first element of the GPUArray
|
13 | 13 | # as a startvalue, which is a bit complicated with the current reduce implementation
|
14 | 14 | function startvalue(f, T)
|
15 |
| - error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, 1, A)") |
| 15 | + error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, A; init = 1)") |
16 | 16 | end
|
17 | 17 | startvalue(::typeof(+), T) = zero(T)
|
18 | 18 | startvalue(::typeof(Base.add_sum), T) = zero(T)
|
@@ -50,20 +50,30 @@ gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(
|
50 | 50 | gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
|
51 | 51 | gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
|
52 | 52 |
|
53 |
| -function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}) where {T, N} |
| 53 | +function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}; dims = :, init...) where {T, N} |
| 54 | + mapreduce_impl(f, op, init.data, A, dims) |
| 55 | +end |
| 56 | + |
| 57 | +function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUArray{T, N}, ::Colon) where {T, N} |
54 | 58 | OT = gpu_promote_type(op, T)
|
55 | 59 | v0 = startvalue(op, OT) # TODO do this better
|
56 |
| - mapreduce(f, op, v0, A) |
| 60 | + acc_mapreduce(f, op, v0, A, ()) |
57 | 61 | end
|
58 |
| -function acc_mapreduce end |
59 |
| -function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number) |
60 |
| - acc_mapreduce(f, op, v0, A, (B, C)) |
| 62 | + |
| 63 | +function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUArray{T, N}, ::Colon) where {T, N} |
| 64 | + acc_mapreduce(f, op, nt.init, A, ()) |
61 | 65 | end
|
62 |
| -function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray) |
63 |
| - acc_mapreduce(f, op, v0, A, (B,)) |
| 66 | + |
| 67 | +function mapreduce_impl(f, op, nt, A::GPUArray{T, N}, dims) where {T, N} |
| 68 | + Base._mapreduce_dim(f, op, nt, A, dims) |
64 | 69 | end
|
65 |
| -function Base.mapreduce(f, op, v0, A::GPUArray) |
66 |
| - acc_mapreduce(f, op, v0, A, ()) |
| 70 | + |
| 71 | +function acc_mapreduce end |
| 72 | +function Base.mapreduce(f, op, A::GPUArray, B::GPUArray, C::Number; init) |
| 73 | + acc_mapreduce(f, op, init, A, (B, C)) |
| 74 | +end |
| 75 | +function Base.mapreduce(f, op, A::GPUArray, B::GPUArray; init) |
| 76 | + acc_mapreduce(f, op, init, A, (B,)) |
67 | 77 | end
|
68 | 78 |
|
69 | 79 | @generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N
|
|
0 commit comments