Skip to content

Commit c6dd4a6

Browse files
committed
Another attempt at fixing showing, test cases for showing and mapreduce
1 parent 332659b commit c6dd4a6

File tree

4 files changed

+42
-8
lines changed

4 files changed

+42
-8
lines changed

src/abstractarray.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,24 @@ end
5656

5757
## showing
5858

59-
Base.show(io::IO, x::GPUArray) = Base.show(io, Array(x))
60-
Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) =
61-
Base.show(io, LinearAlgebra.adjoint(Array(x.parent)))
62-
Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) =
63-
Base.show(io, LinearAlgebra.transpose(Array(x.parent)))
64-
65-
Base.show_vector(io::IO, x::GPUArray) = Base.show_vector(io, Array(x))
66-
59+
for (atype, op) in
60+
[(:(GPUArray), :(Array)),
61+
(:(LinearAlgebra.Adjoint{<:Any,<:GPUArray}), :(x->LinearAlgebra.adjoint(Array(x.parent)))),
62+
(:(LinearAlgebra.Transpose{<:Any,<:GPUArray}), :(x->LinearAlgebra.transpose(Array(x.parent))))]
63+
@eval begin
64+
# for display
65+
Base.print_array(io::IO, X::($atype)) =
66+
Base.print_array(io,($op)(X))
67+
68+
# for show
69+
Base._show_nonempty(io::IO, X::($atype), prefix::String) =
70+
Base._show_nonempty(io,($op)(X),prefix)
71+
Base._show_empty(io::IO, X::($atype)) =
72+
Base._show_empty(io,($op)(X))
73+
Base.show_vector(io::IO, v::($atype), args...) =
74+
Base.show_vector(io,($op)(v),args...)
75+
end
76+
end
6777

6878
# memory operations
6979

src/testsuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ end
4949
Runs the entire GPUArrays test suite on array type `AT`
5050
"""
5151
function test(AT::Type{<:GPUArray})
52+
GPUArrays.allowscalar(false)
5253
TestSuite.test_construction(AT)
5354
TestSuite.test_gpuinterface(AT)
5455
TestSuite.test_indexing(AT)

src/testsuite/io.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,25 @@ function test_io(AT)
33
@testset "showing" begin
44
io = IOBuffer()
55
A = AT(Int64[1])
6+
B = AT(Int64[1 2;3 4]) # vectors and non-vector arrays showing
7+
# are handled differently in base/arrayshow.jl
68

79
show(io, MIME("text/plain"), A)
810
seekstart(io)
911
@test String(take!(io)) == "1-element $AT{Int64,1}:\n 1"
1012

13+
show(io, A)
14+
seekstart(io)
15+
@test String(take!(io)) == "[1]"
16+
17+
show(io, MIME("text/plain"), B)
18+
seekstart(io)
19+
@test String(take!(io)) == "2×2 $AT{Int64,2}:\n 1 2\n 3 4"
20+
21+
show(io, B)
22+
seekstart(io)
23+
@test String(take!(io)) == "[1 2; 3 4]"
24+
1125
show(io, MIME("text/plain"), A')
1226
seekstart(io)
1327
msg = String(take!(io)) # the printing of Adjoint depends on global state

src/testsuite/mapreduce.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ function test_mapreduce(AT)
2121
x = T(y)
2222
@test sum(y, dims = 2) Array(sum(x, dims = 2))
2323
@test sum(y, dims = 1) Array(sum(x, dims = 1))
24+
25+
y = rand(range, N, N)
26+
x = T(y)
27+
_zero = zero(ET)
28+
_addone(z) = z + one(ET)
29+
@test mapreduce(_addone, +, y; dims = 2, init = _zero)
30+
Array(mapreduce(_addone, +, x; dims = 2, init = _zero))
31+
@test mapreduce(_addone, +, y; init = _zero)
32+
mapreduce(_addone, +, x; init = _zero)
2433
end
2534
end
2635
@testset "sum maximum minimum prod" begin

0 commit comments

Comments
 (0)