Skip to content

Commit 712790f

Browse files
authored
Merge pull request #153 from JuliaGPU/tb/mapreduce
Fix tests for CuArrays
2 parents e10a9b9 + 8575e50 commit 712790f

File tree

1 file changed

+16
-26
lines changed

1 file changed

+16
-26
lines changed

src/testsuite/mapreduce.jl

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,30 @@ function test_mapreduce(AT)
66
range = ET <: Integer ? (ET(-2):ET(2)) : ET
77
@testset "mapreducedim" begin
88
for N in (2, 10)
9-
y = rand(range, N, N)
10-
x = T(y)
11-
@test sum(y, dims = 2) Array(sum(x, dims = 2))
12-
@test sum(y, dims = 1) Array(sum(x, dims = 1))
13-
@test sum(y, dims = (1, 2)) Array(sum(x, dims = (1, 2)))
9+
@test compare(x -> sum(x, dims=2), AT, rand(range, N, N))
10+
@test compare(x -> sum(x, dims=1), AT, rand(range, N, N))
11+
@test compare(x -> sum(x, dims=(1, 2)), AT, rand(range, N, N))
1412

15-
y = rand(range, N, 10)
16-
x = T(y)
17-
@test sum(y, dims = 2) Array(sum(x, dims = 2))
18-
@test sum(y, dims = 1) Array(sum(x, dims = 1))
13+
@test compare(x -> sum(x, dims=2), AT, rand(range, N, 10))
14+
@test compare(x -> sum(x, dims=1), AT, rand(range, N, 10))
1915

20-
y = rand(range, 10, N)
21-
x = T(y)
22-
@test sum(y, dims = 2) Array(sum(x, dims = 2))
23-
@test sum(y, dims = 1) Array(sum(x, dims = 1))
16+
@test compare(x -> sum(x, dims=2), AT, rand(range, 10, N))
17+
@test compare(x -> sum(x, dims=1), AT, rand(range, 10, N))
2418

25-
y = rand(range, N, N)
26-
x = T(y)
2719
_zero = zero(ET)
28-
_addone(z) = z + one(ET)
29-
@test mapreduce(_addone, +, y; dims = 2, init = _zero)
30-
Array(mapreduce(_addone, +, x; dims = 2, init = _zero))
31-
@test mapreduce(_addone, +, y; init = _zero)
32-
mapreduce(_addone, +, x; init = _zero)
20+
_addone(z) = z + one(z)
21+
@test compare(x->mapreduce(_addone, +, x; dims = 2),
22+
AT, rand(range, N, N))
23+
@test compare(x->mapreduce(_addone, +, x; dims = 2, init = _zero),
24+
AT, rand(range, N, N))
3325
end
3426
end
3527
@testset "sum maximum minimum prod" begin
3628
for dims in ((4048,), (1024,1024), (77,), (1923,209))
37-
Ac = rand(range, dims)
38-
A = T(Ac)
39-
@test sum(A) sum(Ac)
40-
ET <: Complex || @test maximum(A) maximum(Ac)
41-
ET <: Complex || @test minimum(A) minimum(Ac)
42-
@test prod(A) prod(Ac)
29+
@test compare(sum, AT, rand(range, dims))
30+
@test compare(prod, AT, rand(range, dims))
31+
ET <: Complex || @test compare(maximum, AT,rand(range, dims))
32+
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
4333
end
4434
end
4535
end

0 commit comments

Comments
 (0)