Skip to content

Commit 9c696cd

Browse files
committed
Use Metal.jl native rand by default
1 parent 36873f1 commit 9c696cd

File tree

1 file changed

+5
-23
lines changed

1 file changed

+5
-23
lines changed

src/random.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -227,49 +227,31 @@ Random.randn(rng::RNG, T::Type=Float32) = Random.randn(rng, T, 1)[]
227227
# resolve ambiguities
228228
Random.randn(rng::RNG, T::Random.BitFloatType) = Random.randn(rng, T, 1)[]
229229

230+
################
231+
# RNG-less API #
232+
################
230233

231234
# GPUArrays in-place
232235
Random.rand!(A::MtlArray) = Random.rand!(mtl_rng(), A)
233236
Random.randn!(A::MtlArray) = Random.randn!(mtl_rng(), A)
234237

235-
# Use MPS random functionality where possible
236-
function Random.rand!(A::MPS.UniformArray)
237-
return Random.rand!(mpsrand_rng(), A)
238-
end
239-
function Random.randn!(A::MPS.NormalArray)
240-
return Random.randn!(mpsrand_rng(), A)
241-
end
242-
243238
# GPUArrays out-of-place
244-
function rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode)
245-
return Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
246-
end
247-
function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
248-
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
249-
end
250239
rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
251240
Random.rand!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
252241
randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
253242
Random.randn!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
254243

255244
# support all dimension specifications
256-
function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
257-
return Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
258-
end
259-
function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
260-
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
261-
end
262-
263245
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
264246
Random.rand!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
265247
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
266248
Random.randn!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
267249

268250
# untyped out-of-place
269251
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
270-
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
252+
Random.rand!(mtl_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
271253
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
272-
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
254+
Random.randn!(mtl_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
273255

274256
# scalars
275257
rand(T::Type=Float32; storage=SharedStorage) = rand(T, 1; storage)[1]

0 commit comments

Comments
 (0)