Skip to content

Commit c2841ec

Browse files
authored
Merge pull request #101 from schmrlng/pull-request/b7b88010
Added all to mapreduce.jl, tests for any/all and rand/rand!
2 parents 1ac8cc6 + b7b8801 commit c2841ec

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

src/mapreduce.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import Base: any, count, countnz
1+
import Base: any, all, count, countnz
22

33
#############################
44
# reduce
55
# functions in base implemented with a direct loop need to be overloaded to use mapreduce
66
any(pred, A::GPUArray) = Bool(mapreduce(pred, |, Int32(0), A))
7+
all(pred, A::GPUArray) = Bool(mapreduce(pred, &, Int32(1), A))
78
count(pred, A::GPUArray) = Int(mapreduce(pred, +, UInt32(0), A))
89
countnz(A::GPUArray) = Int(mapreduce(x-> x != 0, +, UInt32(0), A))
910
countnz(A::GPUArray, dim) = Int(mapreducedim(x-> x != 0, +, UInt32(0), A, dim))

src/testsuite/mapreduce.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,13 @@ function run_mapreduce(Typ)
3838
end
3939
end
4040
end
41+
@testset "any all" begin
42+
for Ac in ([false, false], [false, true], [true, true])
43+
A = Typ(Ac)
44+
@test typeof(A) == Typ{Bool,1}
45+
@test any(A) == any(Ac)
46+
@test all(A) == all(Ac)
47+
end
48+
end
4149
end
4250
end

src/testsuite/random.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using GPUArrays
2+
using Base.Test, GPUArrays.TestSuite
3+
4+
function run_random(Typ)
5+
@testset "Random" begin
6+
@testset "rand" begin # uniform
7+
for T in (Float32, Float64)
8+
@test length(rand(Typ{T,1}, (4,))) == 4
9+
@test length(rand(Typ, T, 4)) == 4
10+
@test length(rand(Typ{T,2}, (4,5))) == 20
11+
@test length(rand(Typ, T, 4, 5)) == 20
12+
A = rand(Typ{T,2}, (2,2))
13+
B = copy(A)
14+
@test all(A .== B)
15+
rand!(B)
16+
@test !any(A .== B)
17+
end
18+
end
19+
end
20+
end

src/testsuite/testsuite.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ include("mapreduce.jl")
4141
include("base.jl")
4242
include("indexing.jl")
4343
# include("vector.jl")
44+
include("random.jl")
4445

4546
function supported_eltypes()
4647
(Float32, Float64, Int32, Int64, Complex64, Complex128)
@@ -60,6 +61,7 @@ function run_tests(Typ)
6061
run_linalg(Typ)
6162
run_mapreduce(Typ)
6263
run_indexing(Typ)
64+
run_random(Typ)
6365
end
6466

6567
export against_base, run_tests, supported_eltypes

0 commit comments

Comments
 (0)