@@ -9,28 +9,20 @@ mpsrand_rng() = MPS.default_rng()
99Random. rand! (A:: MtlArray ) = Random. rand! (gpuarrays_rng (), A)
1010Random. randn! (A:: MtlArray ) = Random. randn! (gpuarrays_rng (), A)
1111
12- @inline function can_use_mpsrandom (A:: MtlArray{T} ) where {T}
13- return A. offset * sizeof (T) % 4 == 0 && sizeof (A) % 4 == 0
14- end
15-
1612# Use MPS random functionality where possible
1713function Random. rand! (A:: MPS.UniformArray )
18- rng = can_use_mpsrandom (A) ? mpsrand_rng () : gpuarrays_rng ()
19- return Random. rand! (rng, A)
14+ return Random. rand! (mpsrand_rng (), A)
2015end
2116function Random. randn! (A:: MPS.NormalArray )
22- rng = can_use_mpsrandom (A) ? mpsrand_rng () : gpuarrays_rng ()
23- return Random. randn! (rng, A)
17+ return Random. randn! (mpsrand_rng (), A)
2418end
2519
2620# GPUArrays out-of-place
2721function rand (T:: MPS.UniformType , dims:: Dims ; storage= DefaultStorageMode)
28- rng = prod (dims) * sizeof (T) % 4 == 0 ? mpsrand_rng () : gpuarrays_rng ()
29- return Random. rand! (rng, MtlArray {T,length(dims),storage} (undef, dims... ))
22+ return Random. rand! (mpsrand_rng (), MtlArray {T,length(dims),storage} (undef, dims... ))
3023end
3124function randn (T:: MPS.NormalType , dims:: Dims ; storage= DefaultStorageMode)
32- rng = prod (dims) * sizeof (T) % 4 == 0 ? mpsrand_rng () : gpuarrays_rng ()
33- return Random. randn! (rng, MtlArray {T,length(dims),storage} (undef, dims... ))
25+ return Random. randn! (mpsrand_rng (), MtlArray {T,length(dims),storage} (undef, dims... ))
3426end
3527rand (T:: Type , dims:: Dims ; storage= DefaultStorageMode) =
3628 Random. rand! (gpuarrays_rng (), MtlArray {T,length(dims),storage} (undef, dims... ))
@@ -39,12 +31,10 @@ randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
3931
4032# support all dimension specifications
4133function rand (T:: MPS.UniformType , dim1:: Integer , dims:: Integer... ; storage= DefaultStorageMode)
42- rng = (dim1 * prod (dims) * sizeof (T)) % 4 == 0 ? mpsrand_rng () : gpuarrays_rng ()
43- return Random. rand! (rng, MtlArray {T,length(dims) + 1,storage} (undef, dim1, dims... ))
34+ return Random. rand! (mpsrand_rng (), MtlArray {T,length(dims) + 1,storage} (undef, dim1, dims... ))
4435end
4536function randn (T:: MPS.NormalType , dim1:: Integer , dims:: Integer... ; storage= DefaultStorageMode)
46- rng = (dim1 * prod (dims) * sizeof (T)) % 4 == 0 ? mpsrand_rng () : gpuarrays_rng ()
47- return Random. randn! (rng, MtlArray {T,length(dims) + 1,storage} (undef, dim1, dims... ))
37+ return Random. randn! (mpsrand_rng (), MtlArray {T,length(dims) + 1,storage} (undef, dim1, dims... ))
4838end
4939
5040rand (T:: Type , dim1:: Integer , dims:: Integer... ; storage= DefaultStorageMode) =
@@ -59,8 +49,8 @@ randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
5949 Random. randn! (mpsrand_rng (), MtlArray {Float32,length(dims) + 1,storage} (undef, dim1, dims... ))
6050
6151# scalars
62- rand (T:: Type = Float32; storage= SharedStorage) = rand (T, 4 ; storage)[1 ]
63- randn (T:: Type = Float32; storage= SharedStorage) = randn (T, 4 ; storage)[1 ]
52+ rand (T:: Type = Float32; storage= SharedStorage) = rand (T, 1 ; storage)[1 ]
53+ randn (T:: Type = Float32; storage= SharedStorage) = randn (T, 1 ; storage)[1 ]
6454
6555# seeding
6656function seed! (seed= Base. rand (UInt64))
0 commit comments