Skip to content

Commit 9443b0f

Browse files
committed
Remove mapreduce implementation, expect users to provide mapreducedim! impl.
1 parent 91fbbac commit 9443b0f

File tree

4 files changed

+174
-179
lines changed

4 files changed

+174
-179
lines changed

src/host/mapreduce.jl

Lines changed: 51 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,180 +1,62 @@
11
# map-reduce
22

3-
Base.any(A::AbstractGPUArray{Bool}) = mapreduce(identity, |, A; init = false)
4-
Base.all(A::AbstractGPUArray{Bool}) = mapreduce(identity, &, A; init = true)
5-
6-
Base.any(f::Function, A::AbstractGPUArray) = mapreduce(f, |, A; init = false)
7-
Base.all(f::Function, A::AbstractGPUArray) = mapreduce(f, &, A; init = true)
8-
Base.count(pred::Function, A::AbstractGPUArray) = Int(mapreduce(pred, +, A; init = 0))
9-
10-
Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B; init = true))
11-
12-
LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = acc_mapreduce(==, &, true, A, adjoint(A))
13-
14-
# hack to get around of fetching the first element of the AbstractGPUArray
15-
# as a startvalue, which is a bit complicated with the current reduce implementation
16-
_initerror(f) = error("Please supply a neutral element for $f. E.g: mapreduce(f, $f, A; init = 1)")
17-
startvalue(f, T) = _initerror(f)
18-
for op = (+, Base.add_sum, *, Base.mul_prod, max, min)
19-
@eval startvalue(::typeof($op), ::Type{Any}) = _initerror($op)
20-
end
21-
22-
startvalue(::typeof(+), T) = zero(T)
23-
startvalue(::typeof(Base.add_sum), T) = zero(T)
24-
startvalue(::typeof(*), T) = one(T)
25-
startvalue(::typeof(Base.mul_prod), T) = one(T)
26-
27-
startvalue(::typeof(max), T) = typemin(T)
28-
startvalue(::typeof(min), T) = typemax(T)
29-
30-
# TODO mirror base
31-
32-
if Int === Int32
33-
const SmallSigned = Union{Int8,Int16}
34-
const SmallUnsigned = Union{UInt8,UInt16}
35-
else
36-
const SmallSigned = Union{Int8,Int16,Int32}
37-
const SmallUnsigned = Union{UInt8,UInt16,Int}
38-
end
39-
40-
const CommonReduceResult = Union{UInt64,UInt128,Int64,Int128,Float16,Float32,Float64}
41-
const WidenReduceResult = Union{SmallSigned, SmallUnsigned}
42-
43-
44-
# TODO widen and support Int64 and use Base.r_promote_type
45-
gpu_promote_type(op, ::Type{T}) where {T} = T
46-
gpu_promote_type(op, ::Type{T}) where {T<: WidenReduceResult} = T
47-
gpu_promote_type(::typeof(+), ::Type{T}) where {T<: WidenReduceResult} = T
48-
gpu_promote_type(::typeof(*), ::Type{T}) where {T<: WidenReduceResult} = T
49-
gpu_promote_type(::typeof(Base.add_sum), ::Type{T}) where {T<:WidenReduceResult} = typeof(Base.add_sum(zero(T), zero(T)))
50-
gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:WidenReduceResult} = typeof(Base.mul_prod(one(T), one(T)))
51-
gpu_promote_type(::typeof(+), ::Type{T}) where {T<:Number} = typeof(zero(T)+zero(T))
52-
gpu_promote_type(::typeof(*), ::Type{T}) where {T<:Number} = typeof(one(T)*one(T))
53-
gpu_promote_type(::typeof(Base.add_sum), ::Type{T}) where {T<:Number} = typeof(Base.add_sum(zero(T), zero(T)))
54-
gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(Base.mul_prod(one(T), one(T)))
55-
gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
56-
gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
57-
gpu_promote_type(::typeof(abs), ::Type{Complex{T}}) where {T} = T
58-
gpu_promote_type(::typeof(abs2), ::Type{Complex{T}}) where {T} = T
59-
60-
import Base.Broadcast: Broadcasted
61-
const GPUSrcArray = Union{Broadcasted{<:AbstractGPUArrayStyle}, <:AbstractGPUArray}
62-
63-
function Base.mapreduce(f::Function, op::Function, A::GPUSrcArray; dims = :, init...)
64-
mapreduce_impl(f, op, init.data, A, dims)
65-
end
66-
67-
function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUSrcArray, ::Colon)
68-
OT = gpu_promote_type(op, gpu_promote_type(f, eltype(A)))
69-
v0 = startvalue(op, OT) # TODO do this better
70-
acc_mapreduce(f, op, v0, A)
71-
end
72-
73-
function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUSrcArray, ::Colon)
74-
acc_mapreduce(f, op, nt.init, A)
75-
end
76-
77-
function mapreduce_impl(f, op, nt, A::GPUSrcArray, dims)
78-
Base._mapreduce_dim(f, op, nt, A, dims)
79-
end
3+
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
4+
# argument `init` value to avoid eager initialization of `R` (if set to something).
5+
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray, init=nothing) = error("Not implemented") # COV_EXCL_LINE
6+
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
7+
8+
neutral_element(op, T) =
9+
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
10+
Please pass it as an explicit argument to (if possible), or register it
11+
globally your operator by defining `GPUArrays.neutral_element(::typeof($op), T)`.""")
12+
neutral_element(::typeof(Base.:(|)), T) = zero(T)
13+
neutral_element(::typeof(Base.:(+)), T) = zero(T)
14+
neutral_element(::typeof(Base.add_sum), T) = zero(T)
15+
neutral_element(::typeof(Base.:(&)), T) = one(T)
16+
neutral_element(::typeof(Base.:(*)), T) = one(T)
17+
neutral_element(::typeof(Base.mul_prod), T) = one(T)
18+
neutral_element(::typeof(Base.min), T) = typemax(T)
19+
neutral_element(::typeof(Base.max), T) = typemin(T)
20+
21+
function Base.mapreduce(f, op, A::AbstractGPUArray; dims=:, init=nothing)
22+
# figure out the destination container type by looking at the initializer element,
23+
# or by relying on inference to reason through the map and reduce functions.
24+
if init === nothing
25+
ET = Base.promote_op(f, eltype(A))
26+
ET = Base.promote_op(op, ET, ET)
27+
(ET === Union{} || ET === Any) &&
28+
error("mapreduce cannot figure the output element type, please pass an explicit init value")
29+
30+
init = neutral_element(op, ET)
31+
else
32+
ET = typeof(init)
33+
end
8034

81-
function acc_mapreduce end
82-
function Base.mapreduce(f, op, A::GPUSrcArray, B::GPUSrcArray, C::Number; init)
83-
acc_mapreduce(f, op, init, A, B, C)
84-
end
85-
function Base.mapreduce(f, op, A::GPUSrcArray, B::GPUSrcArray; init)
86-
acc_mapreduce(f, op, init, A, B)
87-
end
35+
sz = size(A)
36+
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], ndims(A))
37+
R = similar(A, ET, red)
38+
mapreducedim!(f, op, R, A, init)
8839

89-
@generated function mapreducedim_kernel(ctx::AbstractKernelContext, f, op, R, A, range::NTuple{N, Any}) where N
90-
types = (range.parameters...,)
91-
indices = ntuple(i-> Symbol("I_$i"), N)
92-
Iexpr = ntuple(i-> :(I[$i]), N)
93-
body = :(@inbounds R[$(Iexpr...)] = op(R[$(Iexpr...)], f(A[$(indices...)])))
94-
for i = N:-1:1
95-
idxsym = indices[i]
96-
if types[i] == Nothing
97-
body = quote
98-
$idxsym = I[$i]
99-
$body
100-
end
101-
else
102-
rsym = Symbol("r_$i")
103-
body = quote
104-
$(rsym) = range[$i]
105-
for $idxsym in Int(first($rsym)):Int(last($rsym))
106-
$body
107-
end
108-
end
109-
end
110-
body
111-
end
112-
quote
113-
I = @cartesianidx R ctx
114-
$body
115-
return
40+
if dims==Colon()
41+
@allowscalar R[]
42+
else
43+
R
11644
end
11745
end
11846

119-
function Base._mapreducedim!(f, op, R::AbstractGPUArray, A::GPUSrcArray)
120-
range = ifelse.(length.(axes(R)) .== 1, axes(A), nothing)
121-
gpu_call(mapreducedim_kernel, f, op, R, A, range; target=R)
122-
return R
123-
end
124-
125-
@inline simple_broadcast_index(A::AbstractArray, i...) = @inbounds A[i...]
126-
@inline simple_broadcast_index(x, i...) = x
47+
Base.any(A::AbstractGPUArray{Bool}) = mapreduce(identity, |, A)
48+
Base.all(A::AbstractGPUArray{Bool}) = mapreduce(identity, &, A)
12749

128-
for i = 0:10
129-
args = ntuple(x-> Symbol("arg_", x), i)
130-
fargs = ntuple(x-> :(simple_broadcast_index($(args[x]), cartesian_global_index...)), i)
131-
@eval begin
132-
# http://developer.amd.com/resources/articles-whitepapers/opencl-optimization-case-study-simple-reductions/
133-
function reduce_kernel(ctx::AbstractKernelContext, f, op, v0::T, A, ::Val{LMEM}, result, $(args...)) where {T, LMEM}
134-
tmp_local = @LocalMemory(ctx, T, LMEM)
135-
global_index = linear_index(ctx)
136-
acc = v0
137-
# # Loop sequentially over chunks of input vector
138-
@inbounds while global_index <= length(A)
139-
cartesian_global_index = Tuple(CartesianIndices(axes(A))[global_index])
140-
@inbounds element = f(A[cartesian_global_index...], $(fargs...))
141-
acc = op(acc, element)
142-
global_index += global_size(ctx)
143-
end
144-
# Perform parallel reduction
145-
local_index = threadidx(ctx) - 1
146-
@inbounds tmp_local[local_index + 1] = acc
147-
synchronize_threads(ctx)
50+
Base.any(f::Function, A::AbstractGPUArray) = mapreduce(f, |, A)
51+
Base.all(f::Function, A::AbstractGPUArray) = mapreduce(f, &, A)
52+
Base.count(pred::Function, A::AbstractGPUArray) = mapreduce(pred, +, A; init = 0)
14853

149-
offset = blockdim(ctx) ÷ 2
150-
@inbounds while offset > 0
151-
if (local_index < offset)
152-
other = tmp_local[local_index + offset + 1]
153-
mine = tmp_local[local_index + 1]
154-
tmp_local[local_index + 1] = op(mine, other)
155-
end
156-
synchronize_threads(ctx)
157-
offset = offset ÷ 2
158-
end
159-
if local_index == 0
160-
@inbounds result[blockidx(ctx)] = tmp_local[1]
161-
end
162-
return
163-
end
164-
end
54+
Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B))
16555

166-
end
56+
# avoid calling into `initarray!``
57+
Base.sum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.add_sum, R, A)
58+
Base.prod!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.mul_prod, R, A)
59+
Base.maximum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(max, R, A)
60+
Base.minimum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(min, R, A)
16761

168-
function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest...) where {OT}
169-
blocks = 80
170-
threads = 256
171-
if length(A) <= blocks * threads
172-
args = zip(convert_to_cpu(A), convert_to_cpu.(rest)...)
173-
return mapreduce(x-> f(x...), op, args, init = v0)
174-
end
175-
out = similar(A, OT, (blocks,))
176-
fill!(out, v0)
177-
gpu_call(reduce_kernel, f, op, v0, A, Val{threads}(), out, rest...;
178-
target=out, threads=threads, blocks=blocks)
179-
reduce(op, Array(out))
180-
end
62+
LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A))

src/reference.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,4 +284,11 @@ Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
284284
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
285285
reshape(reinterpret(T, A.data), size)
286286

287+
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::AbstractArray, init=nothing)
288+
if init !== nothing
289+
fill!(R, init)
290+
end
291+
@allowscalar Base.mapreducedim!(f, op, R.data, A)
292+
end
293+
287294
end

test/testsuite/broadcasting.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ function broadcasting(AT)
5656
@testset "Adjoint and Transpose" begin
5757
A = AT(rand(ET, N))
5858
A' .= ET(2)
59-
@test all(x->x==ET(2), A)
59+
@test all(isequal(ET(2)'), A)
6060
transpose(A) .= ET(1)
61-
@test all(x->x==ET(1), A)
61+
@test all(isequal(ET(1)), A)
6262
end
6363

6464
############

test/testsuite/mapreduce.jl

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,118 @@
11
function test_mapreduce(AT)
2+
@testset "mapreducedim! $ET" for ET in supported_eltypes() begin
3+
T = AT{ET}
4+
range = ET <: Real ? (ET(1):ET(10)) : ET
5+
for (sz,red) in [(10,)=>(1,), (10,10)=>(1,1), (10,10,10)=>(1,1,1), (10,10,10)=>(10,10,10),
6+
(10,10,10)=>(1,10,10), (10,10,10)=>(10,1,10), (10,10,10)=>(10,10,1)]
7+
@test compare((A,R)->Base.mapreducedim!(identity, +, R, A), AT, rand(range, sz), zeros(ET, red))
8+
@test compare((A,R)->Base.mapreducedim!(identity, *, R, A), AT, rand(range, sz), ones(ET, red))
9+
@test compare((A,R)->Base.mapreducedim!(x->x+x, +, R, A), AT, rand(range, sz), zeros(ET, red))
10+
return
11+
end
12+
end
13+
end
14+
15+
@testset "reducedim! $ET" for ET in supported_eltypes() begin
16+
T = AT{ET}
17+
range = ET <: Real ? (ET(1):ET(10)) : ET
18+
for (sz,red) in [(10,)=>(1,), (10,10)=>(1,1), (10,10,10)=>(1,1,1), (10,10,10)=>(10,10,10),
19+
(10,10,10)=>(1,10,10), (10,10,10)=>(10,1,10), (10,10,10)=>(10,10,1)]
20+
@test compare((A,R)->Base.reducedim!(+, R, A), AT, rand(range, sz), zeros(ET, red))
21+
@test compare((A,R)->Base.reducedim!(*, R, A), AT, rand(range, sz), ones(ET, red))
22+
end
23+
end
24+
end
25+
26+
@testset "mapreduce $ET" for ET in supported_eltypes() begin
27+
T = AT{ET}
28+
range = ET <: Real ? (ET(1):ET(10)) : ET
29+
for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[],
30+
(10,)=>:, (10,10)=>:, (10,10,10)=>:,
31+
(10,10,10)=>[1], (10,10,10)=>[2], (10,10,10)=>[3]]
32+
@test compare(A->mapreduce(identity, +, A; dims=dims, init=zero(ET)), AT, rand(range, sz))
33+
@test compare(A->mapreduce(identity, *, A; dims=dims, init=one(ET)), AT, rand(range, sz))
34+
@test compare(A->mapreduce(x->x+x, +, A; dims=dims, init=zero(ET)), AT, rand(range, sz))
35+
end
36+
end
37+
end
38+
39+
@testset "reduce $ET" for ET in supported_eltypes() begin
40+
T = AT{ET}
41+
range = ET <: Real ? (ET(1):ET(10)) : ET
42+
for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[],
43+
(10,)=>:, (10,10)=>:, (10,10,10)=>:,
44+
(10,10,10)=>[1], (10,10,10)=>[2], (10,10,10)=>[3]]
45+
@test compare(A->reduce(+, A; dims=dims, init=zero(ET)), AT, rand(range, sz))
46+
@test compare(A->reduce(*, A; dims=dims, init=one(ET)), AT, rand(range, sz))
47+
end
48+
end
49+
end
50+
51+
@testset "sum prod minimum maximum $ET" for ET in supported_eltypes() begin
52+
T = AT{ET}
53+
range = ET <: Real ? (ET(1):ET(10)) : ET
54+
for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[],
55+
(10,)=>:, (10,10)=>:, (10,10,10)=>:,
56+
(10,10,10)=>[1], (10,10,10)=>[2], (10,10,10)=>[3]]
57+
@test compare(A->sum(A), AT, rand(range, sz))
58+
@test compare(A->sum(abs, A), AT, rand(range, sz))
59+
@test compare(A->sum(A; dims=dims), AT, rand(range, sz))
60+
@test compare(A->prod(A), AT, rand(range, sz))
61+
@test compare(A->prod(abs, A), AT, rand(range, sz))
62+
@test compare(A->prod(A; dims=dims), AT, rand(range, sz))
63+
if !(ET <: Complex)
64+
@test compare(A->minimum(A), AT, rand(range, sz))
65+
@test compare(A->minimum(x->x*x, A), AT, rand(range, sz))
66+
@test compare(A->minimum(A; dims=dims), AT, rand(range, sz))
67+
@test compare(A->maximum(A), AT, rand(range, sz))
68+
@test compare(A->maximum(x->x*x, A), AT, rand(range, sz))
69+
@test compare(A->maximum(A; dims=dims), AT, rand(range, sz))
70+
end
71+
end
72+
OT = isbitstype(widen(ET)) ? widen(ET) : ET
73+
for (sz,red) in [(10,)=>(1,), (10,10)=>(1,1), (10,10,10)=>(1,1,1), (10,10,10)=>(10,10,10),
74+
(10,10,10)=>(1,10,10), (10,10,10)=>(10,1,10), (10,10,10)=>(10,10,1)]
75+
if !(ET <: Complex)
76+
@test compare((A,R)->minimum!(R, A), AT, rand(range, sz), fill(typemax(ET), red))
77+
@test compare((A,R)->maximum!(R, A), AT, rand(range, sz), fill(typemin(ET), red))
78+
end
79+
end
80+
# smaller-scale test to avoid very large values and roundoff issues
81+
for (sz,red) in [(2,)=>(1,), (2,2)=>(1,1), (2,2,2)=>(1,1,1), (2,2,2)=>(2,2,2),
82+
(2,2,2)=>(1,2,2), (2,2,2)=>(2,1,2), (2,2,2)=>(2,2,1)]
83+
@test compare((A,R)->sum!(R, A), AT, rand(range, sz), zeros(OT, red))
84+
@test compare((A,R)->prod!(R, A), AT, rand(range, sz), ones(OT, red))
85+
end
86+
end
87+
end
88+
89+
@testset "any all count ==" begin
90+
for Ac in ([false, false], [false, true], [true, true],
91+
[false false; false false], [false true; false false],
92+
[true true; false false], [true true; true true])
93+
@test compare(A->any(A), AT, Ac)
94+
@test compare(A->all(A), AT, Ac)
95+
end
96+
for Ac in ([1, 1], [1, 2], [2, 2],
97+
[1 1; 1 1], [1 2; 1 1],
98+
[2 2; 1 1], [2 2; 2 2])
99+
@test compare(A->any(iseven, A), AT, Ac)
100+
@test compare(A->all(iseven, A), AT, Ac)
101+
@test compare(A->count(iseven, A), AT, Ac)
102+
103+
A = AT(Ac)
104+
@test A == copy(A)
105+
@test A !== copy(A)
106+
@test A == deepcopy(A)
107+
@test A !== deepcopy(A)
108+
109+
B = similar(A)
110+
@allowscalar B[1] = 3
111+
@test A != B
112+
end
113+
end
114+
115+
# old tests: can be removed, but left in here for a while to ensure the new impl works
2116
@testset "mapreduce" begin
3117
for ET in supported_eltypes()
4118
T = AT{ET}
@@ -36,14 +150,6 @@ function test_mapreduce(AT)
36150
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
37151
end
38152
end
39-
40-
@testset "broadcasted arrays" begin
41-
for dims in ((4048,), (1024,1024), (77,), (1923,209))
42-
@test compare(x->mapreduce(z -> z + one(z), +,
43-
Broadcast.Broadcasted(+, (x, x));
44-
init = zero(ET)), AT, rand(range, dims))
45-
end
46-
end
47153
end
48154
end
49155
@testset "any all ==" begin

0 commit comments

Comments
 (0)