Skip to content

Commit 1dde978

Browse files
Add wrappers for MPSMatrixRandom (#321)
Co-authored-by: Tim Besard <[email protected]>
1 parent 28576b3 commit 1dde978

File tree

7 files changed

+606
-46
lines changed

7 files changed

+606
-46
lines changed

docs/src/usage/array.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
```@meta
44
DocTestSetup = quote
55
using Metal
6+
using GPUArrays
7+
8+
import Random
9+
Random.seed!(1)
10+
11+
Metal.seed!(1)
612
end
713
```
814

@@ -106,3 +112,42 @@ julia> Base.mapreducedim!(identity, +, b, a)
106112
1×1 MtlMatrix{Float32, Metal.PrivateStorage}:
107113
6.0
108114
```
115+
116+
## Random numbers
117+
118+
Base's convenience functions for generating random numbers are available in Metal as well:
119+
120+
```jldoctest
121+
julia> Metal.rand(2)
122+
2-element MtlVector{Float32, Metal.PrivateStorage}:
123+
0.89025915
124+
0.8946847
125+
126+
julia> Metal.randn(Float32, 2, 1)
127+
2×1 MtlMatrix{Float32, Metal.PrivateStorage}:
128+
1.2279074
129+
1.2518331
130+
```
131+
132+
Behind the scenes, these random numbers come from two different generators: one backed by
133+
[Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixrandom?language=objc),
134+
the other by the generic random number generator from GPUArrays.jl. Operations on these generators are implemented using methods from the Random
135+
standard library:
136+
137+
```jldoctest
138+
julia> using Random, GPUArrays
139+
140+
julia> a = Random.rand(MPS.default_rng(), Float32, 1)
141+
1-element MtlVector{Float32, Metal.PrivateStorage}:
142+
0.89025915
143+
144+
julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a)
145+
1-element MtlVector{Float32, Metal.PrivateStorage}:
146+
0.0705002
147+
```
148+
149+
!!! note
150+
`MPSMatrixRandom` functionality requires Metal.jl >= v1.4
151+
152+
!!! warning
153+
`Random.rand!(::MPS.RNG, args...)` and `Random.randn!(::MPS.RNG, args...)` have a framework limitation that requires both the byte offset and the byte size of the destination array to be multiples of 4.

lib/mps/MPS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ include("kernel.jl")
2828
include("images.jl")
2929
include("matrix.jl")
3030
include("vector.jl")
31+
include("matrixrandom.jl")
3132
include("decomposition.jl")
3233
include("copy.jl")
3334

3435
# integrations
36+
include("random.jl")
3537
include("linalg.jl")
3638

3739
end

lib/mps/matrixrandom.jl

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
@cenum MPSMatrixRandomDistribution::UInt begin
2+
MPSMatrixRandomDistributionDefault = 1
3+
MPSMatrixRandomDistributionUniform = 2
4+
MPSMatrixRandomDistributionNormal = 3
5+
end
6+
7+
#
8+
# matrix random descriptor
9+
#
10+
11+
export MPSMatrixRandomDistributionDescriptor
12+
13+
@objcwrapper immutable=false MPSMatrixRandomDistributionDescriptor <: NSObject
14+
15+
# Objective-C properties exposed on the distribution descriptor.
# All four numeric properties are writable; the name given to `setter=` has to
# match the Objective-C setter of the corresponding property.
@objcproperties MPSMatrixRandomDistributionDescriptor begin
    @autoproperty distributionType::MPSMatrixRandomDistribution
    @autoproperty maximum::Float32 setter=setMaximum
    @autoproperty mean::Float32 setter=setMean
    # the setter was previously misspelled `setMimimum`
    @autoproperty minimum::Float32 setter=setMinimum
    @autoproperty standardDeviation::Float32 setter=setStandardDeviation
end
22+
23+
24+
# Ask the Objective-C class for its `defaultDistributionDescriptor` and wrap the
# returned pointer in the Julia-side object type.
function MPSMatrixRandomDefaultDistributionDescriptor()
    handle = @objc [MPSMatrixRandomDistributionDescriptor defaultDistributionDescriptor]::id{MPSMatrixRandomDistributionDescriptor}
    return MPSMatrixRandomDistributionDescriptor(handle)
end
29+
30+
# Default constructor
31+
MPSMatrixRandomDistributionDescriptor() = MPSMatrixRandomDefaultDistributionDescriptor()
32+
33+
# Build a normal-distribution descriptor from `mean` and `standardDeviation`
# (both passed to Objective-C as Float32) and wrap the returned pointer.
function MPSMatrixRandomNormalDistributionDescriptor(mean, standardDeviation)
    handle = @objc [MPSMatrixRandomDistributionDescriptor normalDistributionDescriptorWithMean:mean::Float32
                                                          standardDeviation:standardDeviation::Float32]::id{MPSMatrixRandomDistributionDescriptor}
    return MPSMatrixRandomDistributionDescriptor(handle)
end
39+
40+
# Build a normal-distribution descriptor from `mean`, `standardDeviation`,
# `minimum` and `maximum` (all passed to Objective-C as Float32) and wrap the
# returned pointer.
function MPSMatrixRandomNormalDistributionDescriptor(mean, standardDeviation, minimum, maximum)
    handle = @objc [MPSMatrixRandomDistributionDescriptor normalDistributionDescriptorWithMean:mean::Float32
                                                          standardDeviation:standardDeviation::Float32
                                                          minimum:minimum::Float32
                                                          maximum:maximum::Float32]::id{MPSMatrixRandomDistributionDescriptor}
    return MPSMatrixRandomDistributionDescriptor(handle)
end
48+
49+
# Build a uniform-distribution descriptor from `minimum` and `maximum` (both
# passed to Objective-C as Float32) and wrap the returned pointer.
function MPSMatrixRandomUniformDistributionDescriptor(minimum, maximum)
    handle = @objc [MPSMatrixRandomDistributionDescriptor uniformDistributionDescriptorWithMinimum:minimum::Float32
                                                          maximum:maximum::Float32]::id{MPSMatrixRandomDistributionDescriptor}
    return MPSMatrixRandomDistributionDescriptor(handle)
end
55+
56+
57+
@objcwrapper immutable=false MPSMatrixRandom <: MPSKernel
58+
59+
# Read-only Objective-C properties of the random kernel base class.
# `destinationDataType` and `distributionType` are plain enum values in the
# framework, not Objective-C objects, so they are declared with their enum
# types rather than as `id{...}` object pointers.
@objcproperties MPSMatrixRandom begin
    @autoproperty batchSize::NSUInteger
    @autoproperty batchStart::NSUInteger
    @autoproperty destinationDataType::MPSDataType
    @autoproperty distributionType::MPSMatrixRandomDistribution
end
65+
66+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, destinationMatrix::MPSMatrix) where {K<:MPSMatrixRandom}
67+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
68+
destinationMatrix:destinationMatrix::id{MPSMatrix}]::Nothing
69+
end
70+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, destinationVector::MPSVector) where {K<:MPSMatrixRandom}
71+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
72+
destinationVector:destinationVector::id{MPSVector}]::Nothing
73+
end
74+
75+
@objcwrapper immutable=false MPSMatrixRandomMTGP32 <: MPSMatrixRandom
76+
@objcwrapper immutable=false MPSMatrixRandomPhilox <: MPSMatrixRandom
77+
78+
# Generate the same three constructor methods for both concrete generator
# kernels. Every method follows the same sequence: `alloc` the Objective-C
# object, wrap it in the Julia type, register `release` as finalizer, then send
# the matching `initWithDevice:...` message and return the wrapper.
for R in [:MPSMatrixRandomMTGP32, :MPSMatrixRandomPhilox]
    @eval begin
        # device only
        function $R(device)
            kernel = @objc [$R alloc]::id{$R}
            obj = $R(kernel)
            finalizer(release, obj)
            @objc [obj::id{$R} initWithDevice:device::id{MTLDevice}]::id{$R}
            return obj
        end
        # device, destination element type and seed
        function $R(device, destinationDataType, seed)
            kernel = @objc [$R alloc]::id{$R}
            obj = $R(kernel)
            finalizer(release, obj)
            @objc [obj::id{$R} initWithDevice:device::id{MTLDevice}
                               destinationDataType:destinationDataType::MPSDataType
                               seed:seed::NSUInteger]::id{$R}
            return obj
        end
        # device, destination element type, seed and distribution descriptor
        function $R(device, destinationDataType, seed, distributionDescriptor)
            kernel = @objc [$R alloc]::id{$R}
            obj = $R(kernel)
            finalizer(release, obj)
            @objc [obj::id{$R} initWithDevice:device::id{MTLDevice}
                               destinationDataType:destinationDataType::MPSDataType
                               seed:seed::NSUInteger
                               distributionDescriptor:distributionDescriptor::id{MPSMatrixRandomDistributionDescriptor}]::id{$R}
            return obj
        end
    end
end
108+
109+
# Send `synchronizeStateOnCommandBuffer:` to the MTGP32 kernel `kern`, encoding
# the operation on `cmdbuf`. Returns `nothing`.
# The receiver must be the `kern` argument; the body previously referenced an
# undefined `obj`.
synchronize_state(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandBuffer) =
    @objc [kern::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing
111+
112+
113+
# Fill `dest` with values generated by the kernel `randkern`.
#
# `T2` is the element type used to describe the MPS vector the kernel writes
# into; it may differ from the element type `T` of `dest`.
#
# Keywords:
# - `queue`: command queue used to create the command buffer.
# - `async`: when `false` (the default), block until the command buffer completed.
#
# Throws an error when the byte size or the byte offset of `dest` is not a
# multiple of 4. Returns `nothing`.
@inline function _mpsmat_rand!(randkern::MPSMatrixRandom, dest::MtlArray{T}, ::Type{T2};
                               queue::MTLCommandQueue = global_queue(device()),
                               async::Bool=false) where {T,T2}
    byteoffset = dest.offset * sizeof(T)
    bytesize = sizeof(dest)

    # Even though `append_copy!` seems to work with any size or offset values, the documentation at
    # https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc
    # mentions that both must be multiples of 4 bytes on macOS, so error when they are not
    (bytesize % 4 == 0) || error(lazy"Destination buffer bytesize ($(bytesize)) must be a multiple of 4.")
    (byteoffset % 4 == 0) || error(lazy"Destination buffer offset ($(byteoffset)) must be a multiple of 4.")

    # Fast path: generate straight into `dest`. Only taken when the size is a
    # multiple of 16 bytes and there is no offset (presumably an MPS requirement
    # for the destination vector -- TODO confirm).
    cmdbuf = if bytesize % 16 == 0 && dest.offset == 0
        MTLCommandBuffer(queue) do cmdbuf
            vecDesc = MPSVectorDescriptor(bytesize ÷ sizeof(T2), T2)
            mpsdest = MPSVector(dest, vecDesc)
            encode!(cmdbuf, randkern, mpsdest)
        end
    else
        # Slow path: generate into a temporary vector, then blit exactly
        # `bytesize` bytes into `dest` at `byteoffset`.
        MTLCommandBuffer(queue) do cmdbuf
            # NOTE(review): this requests 4x the number of `T2` elements needed to
            # cover `bytesize`; looks like deliberate over-allocation -- confirm.
            len = UInt(ceil(bytesize / sizeof(T2)) * 4)
            vecDesc = MPSVectorDescriptor(len, T2)
            tempVec = MPSTemporaryVector(cmdbuf, vecDesc)
            encode!(cmdbuf, randkern, tempVec)
            MTLBlitCommandEncoder(cmdbuf) do enc
                MTL.append_copy!(enc, dest.data[], byteoffset, tempVec.data, tempVec.offset, bytesize)
            end
        end
    end

    async || wait_completed(cmdbuf)
    return
end

lib/mps/random.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
using Random
2+
using Metal: DefaultStorageMode
3+
4+
"""
    MPS.RNG()
    MPS.RNG(seed::Integer)
    MPS.RNG(device::MTLDevice)
    MPS.RNG(device::MTLDevice, seed::Integer)

A random number generator backed by `MPSMatrixRandomPhilox` kernels from Metal
Performance Shaders. One kernel is kept per supported kind of output: integers,
uniformly distributed `Float32`, and normally distributed `Float32`.

When no `seed` is given, one is drawn from `RandomDevice()`; when no `device` is given,
the current `device()` is used.
"""
mutable struct RNG <: AbstractRNG
    # device the kernels were created for
    device::MTLDevice
    # kernel used by `rand!` for integer element types
    uniformInteger::MPSMatrixRandomPhilox
    # kernel used by `rand!` for `Float32`
    uniformFloat32::MPSMatrixRandomPhilox
    # kernel used by `randn!` for `Float32`
    normalFloat32::MPSMatrixRandomPhilox
end
15+
16+
17+
make_seed() = Base.rand(RandomDevice(), UInt)
18+
19+
function RNG(device::MTLDevice, seed::Integer)
20+
seed = seed%UInt
21+
RNG(device,
22+
MPSMatrixRandomPhilox(device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()),
23+
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)),
24+
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)),)
25+
end
26+
@autoreleasepool RNG(seed::Integer) = RNG(device(), seed)
27+
RNG(device::MTLDevice) = RNG(device, make_seed())
28+
29+
@autoreleasepool RNG() = RNG(device(), make_seed())
30+
31+
Base.copy(rng::RNG) = RNG(copy(rng.device), copy(rng.uniformInteger), copy(rng.uniformFloat32), copy(rng.normalFloat32))
32+
33+
# Re-seed `rng` by replacing its three Philox kernels with fresh ones created
# from `seed` on the generator's device. Returns `rng`.
@autoreleasepool function Random.seed!(rng::RNG, seed::Integer)
    # Reduce the seed modulo the UInt range, exactly like the `RNG(device, seed)`
    # constructor does, so that `seed!(rng, s)` and `RNG(device, s)` accept the
    # same seeds (including negative ones) and yield the same streams.
    seed = seed % UInt
    rng.uniformInteger = MPSMatrixRandomPhilox(rng.device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor())
    rng.uniformFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1))
    rng.normalFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1))
    return rng
end
39+
40+
Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())
41+
42+
# Cache of generators, keyed by device; entries are created on first use.
const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}()

# Return the cached generator for the current device, creating and storing a
# new randomly seeded one when the device has none yet.
@autoreleasepool function default_rng()
    current = device()
    get!(() -> RNG(current), GLOBAL_RNGs, current)
end
49+
50+
const UniformTypes = [Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64]
51+
const UniformType = Union{[Type{T} for T in UniformTypes]...}
52+
const UniformArray = MtlArray{<:Union{Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
53+
@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
54+
isempty(A) && return A
55+
_mpsmat_rand!(rng.uniformInteger, A, UInt32)
56+
return A
57+
end
58+
59+
@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{Float32})
60+
isempty(A) && return A
61+
_mpsmat_rand!(rng.uniformFloat32, A, Float32)
62+
return A
63+
end
64+
65+
const NormalType = Type{Float32}
66+
const NormalArray = MtlArray{<:Float32}
67+
@autoreleasepool function Random.randn!(rng::RNG, A::MtlArray{Float32})
68+
isempty(A) && return A
69+
_mpsmat_rand!(rng.normalFloat32, A, Float32)
70+
return A
71+
end
72+
73+
# CPU arrays
74+
# Fill the host array `A` with uniform random values: generate them into a
# shared-storage GPU array of the same size, then copy them over. Returns `A`.
function Random.rand!(rng::RNG, A::AbstractArray{T,N}) where {T <: Union{UniformTypes...}, N}
    if !isempty(A)
        staging = MtlArray{T,N,SharedStorage}(undef, size(A))
        rand!(rng, staging)
        copyto!(A, unsafe_wrap(Array{T}, staging))
    end
    return A
end
81+
# Fill the host array `A` with normally distributed values: generate them into
# a shared-storage GPU array of the same size, then copy them over. Returns `A`.
function Random.randn!(rng::RNG, A::AbstractArray{T,N}) where {T <: Float32, N}
    if !isempty(A)
        staging = MtlArray{T,N,SharedStorage}(undef, size(A))
        randn!(rng, staging)
        copyto!(A, unsafe_wrap(Array{T}, staging))
    end
    return A
end
88+
89+
# Out of place
90+
Random.rand(rng::RNG, T::UniformType, dims::Dims; storage=DefaultStorageMode) =
91+
Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
92+
Random.randn(rng::RNG, T::NormalType, dims::Dims; storage=DefaultStorageMode) =
93+
Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
94+
95+
# support all dimension specifications
96+
Random.rand(rng::RNG, T::UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
97+
Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
98+
Random.randn(rng::RNG, T::NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
99+
Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
100+
101+
# untyped out-of-place
102+
Random.rand(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
103+
Random.rand!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
104+
Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
105+
Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
106+
107+
# scalars
108+
Random.rand(rng::RNG, T::UniformType=Float32; storage=SharedStorage) = rand(rng, T, 4; storage)[1]
109+
Random.randn(rng::RNG, T::NormalType=Float32; storage=SharedStorage) = randn(rng, T, 4; storage)[1]

0 commit comments

Comments
 (0)