|
| 1 | +using Random |
| 2 | +using Metal: DefaultStorageMode |
| 3 | + |
| 4 | +""" |
| 5 | + MPS.RNG() |
| 6 | +
|
| 7 | +A random number generator using `rand()` in a device kernel. |
| 8 | +""" |
| 9 | +mutable struct RNG <: AbstractRNG |
| 10 | + device::MTLDevice |
| 11 | + uniformInteger::MPSMatrixRandomPhilox |
| 12 | + uniformFloat32::MPSMatrixRandomPhilox |
| 13 | + normalFloat32::MPSMatrixRandomPhilox |
| 14 | +end |
| 15 | + |
| 16 | + |
| 17 | +make_seed() = Base.rand(RandomDevice(), UInt) |
| 18 | + |
| 19 | +function RNG(device::MTLDevice, seed::Integer) |
| 20 | + seed = seed%UInt |
| 21 | + RNG(device, |
| 22 | + MPSMatrixRandomPhilox(device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()), |
| 23 | + MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)), |
| 24 | + MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)),) |
| 25 | +end |
| 26 | +@autoreleasepool RNG(seed::Integer) = RNG(device(), seed) |
| 27 | +RNG(device::MTLDevice) = RNG(device, make_seed()) |
| 28 | + |
| 29 | +@autoreleasepool RNG() = RNG(device(), make_seed()) |
| 30 | + |
| 31 | +Base.copy(rng::RNG) = RNG(copy(rng.device), copy(rng.uniformInteger), copy(rng.uniformFloat32), copy(rng.normalFloat32)) |
| 32 | + |
| 33 | +@autoreleasepool function Random.seed!(rng::RNG, seed::Integer) |
| 34 | + rng.uniformInteger = MPSMatrixRandomPhilox(rng.device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()) |
| 35 | + rng.uniformFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)) |
| 36 | + rng.normalFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)) |
| 37 | + return rng |
| 38 | +end |
| 39 | + |
| 40 | +Random.seed!(rng::RNG) = Random.seed!(rng, make_seed()) |
| 41 | + |
| 42 | +const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}() |
| 43 | +@autoreleasepool function default_rng() |
| 44 | + dev = device() |
| 45 | + get!(GLOBAL_RNGs, dev) do |
| 46 | + RNG(dev) |
| 47 | + end |
| 48 | +end |
| 49 | + |
| 50 | +const UniformTypes = [Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64] |
| 51 | +const UniformType = Union{[Type{T} for T in UniformTypes]...} |
| 52 | +const UniformArray = MtlArray{<:Union{Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} |
| 53 | +@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} |
| 54 | + isempty(A) && return A |
| 55 | + _mpsmat_rand!(rng.uniformInteger, A, UInt32) |
| 56 | + return A |
| 57 | +end |
| 58 | + |
| 59 | +@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{Float32}) |
| 60 | + isempty(A) && return A |
| 61 | + _mpsmat_rand!(rng.uniformFloat32, A, Float32) |
| 62 | + return A |
| 63 | +end |
| 64 | + |
| 65 | +const NormalType = Type{Float32} |
| 66 | +const NormalArray = MtlArray{<:Float32} |
| 67 | +@autoreleasepool function Random.randn!(rng::RNG, A::MtlArray{Float32}) |
| 68 | + isempty(A) && return A |
| 69 | + _mpsmat_rand!(rng.normalFloat32, A, Float32) |
| 70 | + return A |
| 71 | +end |
| 72 | + |
| 73 | +# CPU arrays |
| 74 | +function Random.rand!(rng::RNG, A::AbstractArray{T,N}) where {T <: Union{UniformTypes...}, N} |
| 75 | + isempty(A) && return A |
| 76 | + B = MtlArray{T,N,SharedStorage}(undef, size(A)) |
| 77 | + rand!(rng, B) |
| 78 | + copyto!(A, unsafe_wrap(Array{T},B)) |
| 79 | + return A |
| 80 | +end |
| 81 | +function Random.randn!(rng::RNG, A::AbstractArray{T,N}) where {T <: Float32, N} |
| 82 | + isempty(A) && return A |
| 83 | + B = MtlArray{T,N,SharedStorage}(undef, size(A)) |
| 84 | + randn!(rng, B) |
| 85 | + copyto!(A, unsafe_wrap(Array{T},B)) |
| 86 | + return A |
| 87 | +end |
| 88 | + |
| 89 | +# Out of place |
| 90 | +Random.rand(rng::RNG, T::UniformType, dims::Dims; storage=DefaultStorageMode) = |
| 91 | + Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) |
| 92 | +Random.randn(rng::RNG, T::NormalType, dims::Dims; storage=DefaultStorageMode) = |
| 93 | + Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) |
| 94 | + |
| 95 | +# support all dimension specifications |
| 96 | +Random.rand(rng::RNG, T::UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = |
| 97 | + Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) |
| 98 | +Random.randn(rng::RNG, T::NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = |
| 99 | + Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) |
| 100 | + |
| 101 | +# untyped out-of-place |
| 102 | +Random.rand(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = |
| 103 | + Random.rand!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) |
| 104 | +Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = |
| 105 | + Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) |
| 106 | + |
| 107 | +# scalars |
| 108 | +Random.rand(rng::RNG, T::UniformType=Float32; storage=SharedStorage) = rand(rng, T, 4; storage)[1] |
| 109 | +Random.randn(rng::RNG, T::NormalType=Float32; storage=SharedStorage) = randn(rng, T, 4; storage)[1] |
0 commit comments