Skip to content

Commit e155e6b

Browse files
authored
Merge pull request #167 from JuliaGPU/tb/curand
Use an AbstractRNG to control dispatch of rand!, remove rand
2 parents 6b804ea + c7bca9c commit e155e6b

File tree

7 files changed

+35
-41
lines changed

7 files changed

+35
-41
lines changed

.gitlab-ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ stages:
77

88
include:
99
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v0/common.yml'
10-
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v0/test_v0.7.yml'
1110
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v0/test_v1.0.yml'
1211
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v0/test_dev.yml'
1312
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v0/postprocess_coverage.yml'

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ os:
55
- osx
66
dist: trusty
77
julia:
8-
- 0.7
98
- 1.0
109
- nightly
1110
matrix:

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
julia 0.7-alpha
1+
julia 1.0
22
StaticArrays
33
FFTW
44
FillArrays

appveyor.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
environment:
22
matrix:
3-
- julia_version: 0.7
43
- julia_version: 1.0
54
- julia_version: latest
65

src/random.jl

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## device interface
2+
3+
# hybrid Tausworthe and Linear Congruent generator from
4+
# https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch37.html
5+
16
function TausStep(z::Unsigned, S1::Integer, S2::Integer, S3::Integer, M::Unsigned)
27
b = (((z << S1) z) >> S2)
38
return (((z & M) << S3) b)
@@ -8,7 +13,6 @@ LCGStep(z::Unsigned, A::Unsigned, C::Unsigned) = A * z + C
813
make_rand_num(::Type{Float64}, tmp) = 2.3283064365387e-10 * Float64(tmp)
914
make_rand_num(::Type{Float32}, tmp) = 2.3283064f-10 * Float32(tmp)
1015

11-
1216
function next_rand(::Type{FT}, state::NTuple{4, T}) where {FT, T <: Unsigned}
1317
state = (
1418
TausStep(state[1], Cint(13), Cint(19), Cint(12), T(4294967294)),
@@ -42,22 +46,31 @@ function gpu_rand(::Type{T}, state, randstate::AbstractVector{NTuple{4, UInt32}}
4246
return to_number_range(f, T)
4347
end
4448

45-
let rand_state_dict = Dict()
46-
global cached_state, clear_cache
47-
clear_cache() = (empty!(rand_state_dict); return)
48-
function cached_state(x)
49-
dev = GPUArrays.device(x)
50-
get!(rand_state_dict, dev) do
51-
N = GPUArrays.threads(dev)
52-
res = similar(x, NTuple{4, UInt32}, N)
53-
copyto!(res, [ntuple(i-> rand(UInt32), 4) for i=1:N])
54-
res
55-
end
49+
50+
## host interface
51+
52+
struct RNG <: AbstractRNG
53+
state::GPUArray{NTuple{4,UInt32},1}
54+
55+
function RNG(A::GPUArray)
56+
dev = GPUArrays.device(A)
57+
N = GPUArrays.threads(dev)
58+
state = similar(A, NTuple{4, UInt32}, N)
59+
copyto!(state, [ntuple(i-> rand(UInt32), 4) for i=1:N])
60+
new(state)
61+
end
62+
end
63+
64+
const GLOBAL_RNGS = Dict()
65+
function global_rng(A::GPUArray)
66+
dev = GPUArrays.device(A)
67+
get!(GLOBAL_RNGS, dev) do
68+
RNG(A)
5669
end
5770
end
58-
function Random.rand!(A::GPUArray{T}) where T <: Number
59-
rstates = cached_state(A)
60-
gpu_call(A, (rstates, A,)) do state, randstates, a
71+
72+
function Random.rand!(rng::RNG, A::GPUArray{T}) where T <: Number
73+
gpu_call(A, (rng.state, A,)) do state, randstates, a
6174
idx = linear_index(state)
6275
idx > length(a) && return
6376
@inbounds a[idx] = gpu_rand(T, state, randstates)
@@ -66,14 +79,4 @@ function Random.rand!(A::GPUArray{T}) where T <: Number
6679
A
6780
end
6881

69-
Random.rand(X::Type{<: GPUArray}, i::Integer...) = rand(X, Float32, i...)
70-
Random.rand(X::Type{<: GPUArray}, size::NTuple{N, Int}) where N = rand(X, Float32, size...)
71-
Random.rand(X::Type{<: GPUArray{T}}, i::Integer...) where T = rand(X, T, i...)
72-
Random.rand(X::Type{<: GPUArray{T}}, size::NTuple{N, Int}) where {T, N} = rand(X, T, size...)
73-
Random.rand(X::Type{<: GPUArray{T, N}}, size::NTuple{N, Integer}) where {T, N} = rand(X, T, size...)
74-
Random.rand(X::Type{<: GPUArray{T, N}}, size::NTuple{N, Int}) where {T, N} = rand(X, T, size...)
75-
76-
function Random.rand(X::Type{<: GPUArray}, ::Type{ET}, size::Integer...) where ET
77-
A = similar(X, ET, size)
78-
rand!(A)
79-
end
82+
Random.rand!(A::GPUArray) = rand!(global_rng(A), A)

src/testsuite/indexing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ function test_indexing(AT)
1313
end
1414
@testset "multi dim, sliced setindex" begin
1515
x = fill(AT{T}, T(0), (10, 10, 10, 10))
16-
y = rand(AT{T}, 5, 5, 10, 10)
16+
y = AT{T}(5, 5, 10, 10)
17+
rand!(y)
1718
x[2:6, 2:6, :, :] = y
1819
x[2:6, 2:6, :, :] == y
1920
end

src/testsuite/random.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
11
function test_random(AT)
22
@testset "Random" begin
33
@testset "rand" begin # uniform
4-
for T in (Float32, Float64, Int64, Int32)
5-
@test length(rand(AT{T,1}, (4,))) == 4
6-
@test length(rand(AT{T}, (4,))) == 4
7-
@test length(rand(AT{T}, 4)) == 4
8-
@test eltype(rand(AT, 4)) == Float32
9-
@test length(rand(AT, T, 4)) == 4
10-
@test length(rand(AT{T,2}, (4,5))) == 20
11-
@test length(rand(AT, T, 4, 5)) == 20
12-
A = rand(AT{T,2}, (2,2))
4+
for T in (Float32, Float64, Int64, Int32), d in (2, (2,2))
5+
A = AT{T}(d)
136
B = copy(A)
14-
@test all(A .== B)
7+
rand!(A)
158
rand!(B)
169
@test !any(A .== B)
1710
end

0 commit comments

Comments
 (0)