@@ -66,8 +66,7 @@ function global_rng(A::GPUArray)
66
66
end
67
67
end
68
68
69
- function Random. rand! (A:: GPUArray{T} ) where T <: Number
70
- rng = global_rng (A)
69
+ function Random. rand! (rng:: RNG , A:: GPUArray{T} ) where T <: Number
71
70
gpu_call (A, (rng. state, A,)) do state, randstates, a
72
71
idx = linear_index (state)
73
72
idx > length (a) && return
@@ -77,13 +76,4 @@ function Random.rand!(A::GPUArray{T}) where T <: Number
77
76
A
78
77
end
79
78
80
- gpurand (:: Type{X} , dims... ) where {X<: GPUArray } = randn! (X (dims... ))
81
-
82
- # FIXME : the following definitions are not part of the Random stdlib API
83
- Random. rand (X:: Type{<: GPUArray} , i:: Integer... ) = rand! (X {Float32} (i... ))
84
- Random. rand (X:: Type{<: GPUArray} , size:: NTuple{N, Int} ) where N = rand! (X {Float32} (size... ))
85
- Random. rand (X:: Type{<: GPUArray{T}} , i:: Integer... ) where {T} = rand! (X (i... ))
86
- Random. rand (X:: Type{<: GPUArray{T}} , size:: NTuple{N, Int} ) where {T,N} = rand! (X (size... ))
87
- Random. rand (X:: Type{<: GPUArray{T, N}} , size:: NTuple{N, Integer} ) where {T,N} = rand! (X (size... ))
88
- Random. rand (X:: Type{<: GPUArray{T, N}} , size:: NTuple{N, Int} ) where {T,N} = rand! (X (size... ))
89
- Random. rand (X:: Type{<: GPUArray} , :: Type{T} , size:: Integer... ) where {T} = rand! (similar (X, T, size))
79
+ Random. rand! (A:: GPUArray ) = rand! (global_rng (A), A)
0 commit comments