Skip to content

Commit 4463977

Browse files
authored
Merge pull request #242 from JuliaGPU/tb/seed
Add ability to seed the RNG, and do so initially.
2 parents 0627899 + ca0dc9c commit 4463977

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/host/random.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,29 @@ struct RNG <: AbstractRNG
6666
end
6767

6868
const GLOBAL_RNGS = Dict()
69-
function global_rng(A::AbstractGPUArray)
70-
dev = GPUArrays.device(A)
69+
function global_rng(AT::Type{<:AbstractGPUArray}, dev)
7170
get!(GLOBAL_RNGS, dev) do
7271
N = GPUArrays.threads(dev)
73-
state = similar(A, NTuple{4, UInt32}, N)
74-
copyto!(state, [ntuple(i-> rand(UInt32), 4) for i=1:N])
75-
RNG(state)
72+
AT = Base.typename(AT).wrapper
73+
state = AT{NTuple{4, UInt32}}(undef, N)
74+
rng = RNG(state)
75+
Random.seed!(rng)
76+
rng
7677
end
7778
end
79+
global_rng(A::AT) where {AT <: AbstractGPUArray} = global_rng(AT, GPUArrays.device(A))
80+
81+
make_seed(rng::RNG) = make_seed(rng, rand(UInt))
82+
function make_seed(rng::RNG, n::Integer)
83+
rand(MersenneTwister(n), UInt32, sizeof(rng.state)÷sizeof(UInt32))
84+
end
85+
86+
Random.seed!(rng::RNG) = Random.seed!(rng, make_seed(rng))
87+
Random.seed!(rng::RNG, seed::Integer) = Random.seed!(rng, make_seed(rng, seed))
88+
function Random.seed!(rng::RNG, seed::Vector{UInt32})
89+
copyto!(rng.state, reinterpret(NTuple{4, UInt32}, seed))
90+
return
91+
end
7892

7993
function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
8094
gpu_call(A, rng.state) do ctx, a, randstates

test/testsuite/random.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ function test_random(AT)
99
rand!(A)
1010
rand!(B)
1111
@test !any(A .== B)
12+
13+
rng = GPUArrays.global_rng(A)
14+
Random.seed!(rng)
15+
Random.seed!(rng, 1)
16+
rand!(rng, A)
17+
Random.seed!(rng, 1)
18+
rand!(rng, B)
19+
@test all(A .== B)
1220
end
1321
end
1422
end

0 commit comments

Comments
 (0)