@@ -66,15 +66,29 @@ struct RNG <: AbstractRNG
66
66
end
67
67
68
68
const GLOBAL_RNGS = Dict ()
69
- function global_rng (A:: AbstractGPUArray )
70
- dev = GPUArrays. device (A)
69
+ function global_rng (AT:: Type{<:AbstractGPUArray} , dev)
71
70
get! (GLOBAL_RNGS, dev) do
72
71
N = GPUArrays. threads (dev)
73
- state = similar (A, NTuple{4 , UInt32}, N)
74
- copyto! (state, [ntuple (i-> rand (UInt32), 4 ) for i= 1 : N])
75
- RNG (state)
72
+ AT = Base. typename (AT). wrapper
73
+ state = AT {NTuple{4, UInt32}} (undef, N)
74
+ rng = RNG (state)
75
+ Random. seed! (rng)
76
+ rng
76
77
end
77
78
end
79
+ global_rng (A:: AT ) where {AT <: AbstractGPUArray } = global_rng (AT, GPUArrays. device (A))
80
+
81
+ make_seed (rng:: RNG ) = make_seed (rng, rand (UInt))
82
+ function make_seed (rng:: RNG , n:: Integer )
83
+ rand (MersenneTwister (n), UInt32, sizeof (rng. state)÷ sizeof (UInt32))
84
+ end
85
+
86
+ Random. seed! (rng:: RNG ) = Random. seed! (rng, make_seed (rng))
87
+ Random. seed! (rng:: RNG , seed:: Integer ) = Random. seed! (rng, make_seed (rng, seed))
88
+ function Random. seed! (rng:: RNG , seed:: Vector{UInt32} )
89
+ copyto! (rng. state, reinterpret (NTuple{4 , UInt32}, seed))
90
+ return
91
+ end
78
92
79
93
function Random. rand! (rng:: RNG , A:: AbstractGPUArray{T} ) where T <: Number
80
94
gpu_call (A, rng. state) do ctx, a, randstates
0 commit comments