Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions src/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,26 +227,27 @@ Random.randn(rng::RNG, T::Type=Float32) = Random.randn(rng, T, 1)[]
# resolve ambiguities
Random.randn(rng::RNG, T::Random.BitFloatType) = Random.randn(rng, T, 1)[]

############################################
# RNG-less API #
# Use MPS for uniformly distributed RNG, #
# but native rand for normally distributed #
# to work around JuliaGPU/Metal.jl#474 #
############################################

# GPUArrays in-place
Random.rand!(A::MtlArray) = Random.rand!(mtl_rng(), A)
Random.randn!(A::MtlArray) = Random.randn!(mtl_rng(), A)

# Use MPS random functionality where possible
# Use MPS random functionality where possible for uniformly distributed RNG
function Random.rand!(A::MPS.UniformArray)
return Random.rand!(mpsrand_rng(), A)
end
function Random.randn!(A::MPS.NormalArray)
return Random.randn!(mpsrand_rng(), A)
end

# GPUArrays out-of-place
function rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode)
return Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
end
function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
end

rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(mtl_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
Expand All @@ -256,9 +257,6 @@ randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
return Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
end
function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
end

rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(mtl_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
Expand All @@ -269,7 +267,7 @@ randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
Random.randn!(mtl_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# scalars
rand(T::Type=Float32; storage=SharedStorage) = rand(T, 1; storage)[1]
Expand Down
37 changes: 37 additions & 0 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,40 @@ end
rand!(rng, c)
randn!(rng, c)
end

@testset "randn NaN (Issue #474)" begin
SEED = 1234
N=100000000

# randn!
let X = Metal.zeros(Float32, N)
Metal.seed!(SEED)
randn!(X)
nans = findall(isnan, Array(X))
@test isempty(nans)
end

# randn(T, dims::Dims)
let
Metal.seed!(SEED)
X = Metal.randn(Float32, Dims(N))
nans = findall(isnan, Array(X))
@test isempty(nans)
end

# randn(T, dim1::Integer, dims...)
let
Metal.seed!(SEED)
X = Metal.randn(Float32, N)
nans = findall(isnan, Array(X))
@test isempty(nans)
end

# randn(dim1)
let
Metal.seed!(SEED)
X = Metal.randn(N)
nans = findall(isnan, Array(X))
@test isempty(nans)
end
end