Skip to content

Commit fa8f65e

Browse files
committed
Use an AbstractRNG to control dispatch of rand!
1 parent 1b42b98 commit fa8f65e

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

src/random.jl

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
## device interface
2+
13
function TausStep(z::Unsigned, S1::Integer, S2::Integer, S3::Integer, M::Unsigned)
24
b = (((z << S1) ⊻ z) >> S2)
35
return (((z & M) << S3) ⊻ b)
@@ -8,7 +10,6 @@ LCGStep(z::Unsigned, A::Unsigned, C::Unsigned) = A * z + C
810
# Convert `tmp` to Float64 and scale it by 2.3283064365387e-10 (approximately 2^-32).
# NOTE(review): presumably `tmp` is a 32-bit unsigned draw, so the result lands in
# the unit interval — confirm against the caller.
make_rand_num(::Type{Float64}, tmp) = 2.3283064365387e-10 * Float64(tmp)
911
# Convert `tmp` to Float32 and scale it by 2.3283064f-10 (approximately 2^-32).
# NOTE(review): presumably `tmp` is a 32-bit unsigned draw, so the result lands in
# the unit interval — confirm against the caller.
make_rand_num(::Type{Float32}, tmp) = 2.3283064f-10 * Float32(tmp)
1012

11-
1213
function next_rand(::Type{FT}, state::NTuple{4, T}) where {FT, T <: Unsigned}
1314
state = (
1415
TausStep(state[1], Cint(13), Cint(19), Cint(12), T(4294967294)),
@@ -42,22 +43,32 @@ function gpu_rand(::Type{T}, state, randstate::AbstractVector{NTuple{4, UInt32}}
4243
return to_number_range(f, T)
4344
end
4445

45-
let rand_state_dict = Dict()
46-
global cached_state, clear_cache
47-
clear_cache() = (empty!(rand_state_dict); return)
48-
function cached_state(x)
49-
dev = GPUArrays.device(x)
50-
get!(rand_state_dict, dev) do
51-
N = GPUArrays.threads(dev)
52-
res = similar(x, NTuple{4, UInt32}, N)
53-
copyto!(res, [ntuple(i-> rand(UInt32), 4) for i=1:N])
54-
res
55-
end
46+
47+
## host interface
48+
49+
"""
    RNG(A::GPUArray)

Random number generator whose state is stored in an array created with
`similar(A, NTuple{4, UInt32}, N)`, where `N = GPUArrays.threads(GPUArrays.device(A))`.

Every state entry is seeded on the host with four `rand(UInt32)` draws and then
transferred with a single `copyto!`.

The concrete array type of the state is a type parameter, so `rng.state` has a
concrete, inferrable type instead of the abstract `GPUArray{NTuple{4, UInt32}, 1}`.
"""
struct RNG{S <: GPUArray{NTuple{4, UInt32}, 1}} <: AbstractRNG
    state::S

    function RNG(A::GPUArray)
        dev = GPUArrays.device(A)
        N = GPUArrays.threads(dev)
        state = similar(A, NTuple{4, UInt32}, N)
        # Build all seeds on the host first so only one host-to-array copy is needed.
        copyto!(state, [ntuple(_ -> rand(UInt32), 4) for _ in 1:N])
        return new{typeof(state)}(state)
    end
end
60+
61+
# Cache of RNGs keyed by the value returned from `GPUArrays.device`.
# The key type stays `Any` because the device type is not visible from this file;
# the value type is fixed to `RNG` so lookups are known to yield an `RNG`.
const GLOBAL_RNGS = Dict{Any, RNG}()

"""
    global_rng(A::GPUArray)

Return the cached `RNG` for the device of `A`. If none is cached yet, construct
one with `RNG(A)`, store it in `GLOBAL_RNGS`, and return it.
"""
function global_rng(A::GPUArray)
    dev = GPUArrays.device(A)
    return get!(() -> RNG(A), GLOBAL_RNGS, dev)
end
68+
5869
function Random.rand!(A::GPUArray{T}) where T <: Number
59-
rstates = cached_state(A)
60-
gpu_call(A, (rstates, A,)) do state, randstates, a
70+
rng = global_rng(A)
71+
gpu_call(A, (rng.state, A,)) do state, randstates, a
6172
idx = linear_index(state)
6273
idx > length(a) && return
6374
@inbounds a[idx] = gpu_rand(T, state, randstates)
@@ -66,14 +77,13 @@ function Random.rand!(A::GPUArray{T}) where T <: Number
6677
A
6778
end
6879

69-
Random.rand(X::Type{<: GPUArray}, i::Integer...) = rand(X, Float32, i...)
70-
Random.rand(X::Type{<: GPUArray}, size::NTuple{N, Int}) where N = rand(X, Float32, size...)
71-
Random.rand(X::Type{<: GPUArray{T}}, i::Integer...) where T = rand(X, T, i...)
72-
Random.rand(X::Type{<: GPUArray{T}}, size::NTuple{N, Int}) where {T, N} = rand(X, T, size...)
73-
Random.rand(X::Type{<: GPUArray{T, N}}, size::NTuple{N, Integer}) where {T, N} = rand(X, T, size...)
74-
Random.rand(X::Type{<: GPUArray{T, N}}, size::NTuple{N, Int}) where {T, N} = rand(X, T, size...)
80+
# Construct an array of type `X` with dimensions `dims` and fill it in place via `rand!`.
# `rand!` (not `randn!`) is used because it is the only in-place sampler with a
# GPUArray method in the visible part of this file.
gpurand(::Type{X}, dims...) where {X<:GPUArray} = rand!(X(dims...))
7581

76-
function Random.rand(X::Type{<: GPUArray}, ::Type{ET}, size::Integer...) where ET
77-
A = similar(X, ET, size)
78-
rand!(A)
79-
end
82+
# FIXME: the following definitions are not part of the Random stdlib API
83+
# Convenience methods: construct an array of the requested GPUArray type and fill it
# in place with `rand!`. When the array type `X` carries no element type, `Float32`
# is used (`X{Float32}`); the last method takes the element type explicitly and
# builds the array with `similar`.
# NOTE(review): the `GPUArray{T, N}` / `NTuple{N, Int}` method has the same body as the
# `NTuple{N, Integer}` one; presumably it exists to resolve a dispatch ambiguity with
# the `GPUArray{T}` / `NTuple{N, Int}` method — confirm before removing either.
Random.rand(X::Type{<: GPUArray}, i::Integer...) = rand!(X{Float32}(i...))
84+
Random.rand(X::Type{<: GPUArray}, size::NTuple{N, Int}) where N = rand!(X{Float32}(size...))
85+
Random.rand(X::Type{<: GPUArray{T}}, i::Integer...) where {T} = rand!(X(i...))
86+
Random.rand(X::Type{<: GPUArray{T}}, size::NTuple{N, Int}) where {T,N} = rand!(X(size...))
87+
Random.rand(X::Type{<: GPUArray{T, N}}, size::NTuple{N, Integer}) where {T,N} = rand!(X(size...))
88+
Random.rand(X::Type{<: GPUArray{T, N}}, size::NTuple{N, Int}) where {T,N} = rand!(X(size...))
89+
Random.rand(X::Type{<: GPUArray}, ::Type{T}, size::Integer...) where {T} = rand!(similar(X, T, size))

0 commit comments

Comments
 (0)