Skip to content

Commit e10a9b9

Browse files
authored
Merge pull request #145 from ssz66666/master
Fix `Base.mapreduce` to match the new signature in 0.7/1.0
2 parents d98ca89 + 9664691 commit e10a9b9

File tree

4 files changed

+69
-22
lines changed

4 files changed

+69
-22
lines changed

src/abstractarray.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,24 @@ end
5656

5757
## showing
5858

59-
Base.show(io::IO, x::GPUArray) = Base.show(io, Array(x))
60-
Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) =
61-
Base.show(io, LinearAlgebra.adjoint(Array(x.parent)))
62-
Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) =
63-
Base.show(io, LinearAlgebra.transpose(Array(x.parent)))
64-
65-
Base.show_vector(io::IO, x::GPUArray) = Base.show_vector(io, Array(x))
66-
59+
for (atype, op) in
60+
[(:(GPUArray), :(Array)),
61+
(:(LinearAlgebra.Adjoint{<:Any,<:GPUArray}), :(x->LinearAlgebra.adjoint(Array(parent(x))))),
62+
(:(LinearAlgebra.Transpose{<:Any,<:GPUArray}), :(x->LinearAlgebra.transpose(Array(parent(x)))))]
63+
@eval begin
64+
# for display
65+
Base.print_array(io::IO, X::($atype)) =
66+
Base.print_array(io,($op)(X))
67+
68+
# for show
69+
Base._show_nonempty(io::IO, X::($atype), prefix::String) =
70+
Base._show_nonempty(io,($op)(X),prefix)
71+
Base._show_empty(io::IO, X::($atype)) =
72+
Base._show_empty(io,($op)(X))
73+
Base.show_vector(io::IO, v::($atype), args...) =
74+
Base.show_vector(io,($op)(v),args...)
75+
end
76+
end
6777

6878
# memory operations
6979

src/mapreduce.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
# functions in base implemented with a direct loop need to be overloaded to use mapreduce
44

55

6-
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A)
7-
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A)
8-
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A))
6+
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false)
7+
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true)
8+
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0))
99

10-
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, true, A, B))
10+
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true))
1111

1212
# hack to get around of fetching the first element of the GPUArray
1313
# as a startvalue, which is a bit complicated with the current reduce implementation
1414
function startvalue(f, T)
15-
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, 1, A)")
15+
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, A; init = 1)")
1616
end
1717
startvalue(::typeof(+), T) = zero(T)
1818
startvalue(::typeof(Base.add_sum), T) = zero(T)
@@ -50,20 +50,30 @@ gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(
5050
gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
5151
gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
5252

53-
function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}) where {T, N}
53+
function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}; dims = :, init...) where {T, N}
54+
mapreduce_impl(f, op, init.data, A, dims)
55+
end
56+
57+
function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUArray{T, N}, ::Colon) where {T, N}
5458
OT = gpu_promote_type(op, T)
5559
v0 = startvalue(op, OT) # TODO do this better
56-
mapreduce(f, op, v0, A)
60+
acc_mapreduce(f, op, v0, A, ())
5761
end
58-
function acc_mapreduce end
59-
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number)
60-
acc_mapreduce(f, op, v0, A, (B, C))
62+
63+
function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUArray{T, N}, ::Colon) where {T, N}
64+
acc_mapreduce(f, op, nt.init, A, ())
6165
end
62-
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray)
63-
acc_mapreduce(f, op, v0, A, (B,))
66+
67+
function mapreduce_impl(f, op, nt, A::GPUArray{T, N}, dims) where {T, N}
68+
Base._mapreduce_dim(f, op, nt, A, dims)
6469
end
65-
function Base.mapreduce(f, op, v0, A::GPUArray)
66-
acc_mapreduce(f, op, v0, A, ())
70+
71+
function acc_mapreduce end
72+
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray, C::Number; init)
73+
acc_mapreduce(f, op, init, A, (B, C))
74+
end
75+
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray; init)
76+
acc_mapreduce(f, op, init, A, (B,))
6777
end
6878

6979
@generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N

src/testsuite/io.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,29 @@ function test_io(AT)
33
@testset "showing" begin
44
io = IOBuffer()
55
A = AT(Int64[1])
6+
B = AT(Int64[1 2;3 4]) # vectors and non-vector arrays showing
7+
# are handled differently in base/arrayshow.jl
68

79
show(io, MIME("text/plain"), A)
810
seekstart(io)
911
@test String(take!(io)) == "1-element $AT{Int64,1}:\n 1"
1012

13+
show(io, A)
14+
seekstart(io)
15+
msg = String(take!(io)) # result of e.g. `print` differs on 32bit and 64bit machines
16+
# due to different definition of `Int` type
17+
# print([1]) shows as [1] on 64bit but Int64[1] on 32bit
18+
@test msg == "[1]" || msg == "Int64[1]"
19+
20+
show(io, MIME("text/plain"), B)
21+
seekstart(io)
22+
@test String(take!(io)) == "2×2 $AT{Int64,2}:\n 1 2\n 3 4"
23+
24+
show(io, B)
25+
seekstart(io)
26+
msg = String(take!(io))
27+
@test msg == "[1 2; 3 4]" || msg == "Int64[1 2; 3 4]"
28+
1129
show(io, MIME("text/plain"), A')
1230
seekstart(io)
1331
msg = String(take!(io)) # the printing of Adjoint depends on global state

src/testsuite/mapreduce.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ function test_mapreduce(AT)
2121
x = T(y)
2222
@test sum(y, dims = 2) Array(sum(x, dims = 2))
2323
@test sum(y, dims = 1) Array(sum(x, dims = 1))
24+
25+
y = rand(range, N, N)
26+
x = T(y)
27+
_zero = zero(ET)
28+
_addone(z) = z + one(ET)
29+
@test mapreduce(_addone, +, y; dims = 2, init = _zero)
30+
Array(mapreduce(_addone, +, x; dims = 2, init = _zero))
31+
@test mapreduce(_addone, +, y; init = _zero)
32+
mapreduce(_addone, +, x; init = _zero)
2433
end
2534
end
2635
@testset "sum maximum minimum prod" begin

0 commit comments

Comments
 (0)