Skip to content

Commit b6062f3

Browse files
Enable MPS rand for more array sizes (#677)
1 parent 04ec7eb commit b6062f3

File tree

4 files changed

+31
-68
lines changed

4 files changed

+31
-68
lines changed

lib/mps/matrixrandom.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,6 @@ synchronize_state(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandBuffer) =
9393
byteoffset = dest.offset * sizeof(T)
9494
bytesize = sizeof(dest)
9595

96-
# Even though `append_copy`` seems to work with any size or offset values, the documentation at
97-
# https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc
98-
# mentions that both must be multiples of 4 bytes in MacOS so error when they are not
99-
(bytesize % 4 == 0) || error(lazy"Destination buffer bytesize ($(bytesize)) must be a multiple of 4.")
100-
(byteoffset % 4 == 0) || error(lazy"Destination buffer offset ($(byteoffset)) must be a multiple of 4.")
101-
10296
cmdbuf = if bytesize % 16 == 0 && dest.offset == 0
10397
MTLCommandBuffer(queue) do cmdbuf
10498
vecDesc = MPSVectorDescriptor(bytesize ÷ sizeof(T2), T2)

lib/mps/random.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,5 @@ Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMo
116116
Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
117117

118118
# scalars
119-
Random.rand(rng::RNG, T::UniformType=Float32; storage=SharedStorage) = rand(rng, T, 4; storage)[1]
120-
Random.randn(rng::RNG, T::NormalType=Float32; storage=SharedStorage) = randn(rng, T, 4; storage)[1]
119+
Random.rand(rng::RNG, T::UniformType=Float32; storage=SharedStorage) = rand(rng, T, 1; storage)[1]
120+
Random.randn(rng::RNG, T::NormalType=Float32; storage=SharedStorage) = randn(rng, T, 1; storage)[1]

src/random.jl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,20 @@ mpsrand_rng() = MPS.default_rng()
99
Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A)
1010
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A)
1111

12-
@inline function can_use_mpsrandom(A::MtlArray{T}) where {T}
13-
return A.offset * sizeof(T) % 4 == 0 && sizeof(A) % 4 == 0
14-
end
15-
1612
# Use MPS random functionality where possible
1713
function Random.rand!(A::MPS.UniformArray)
18-
rng = can_use_mpsrandom(A) ? mpsrand_rng() : gpuarrays_rng()
19-
return Random.rand!(rng, A)
14+
return Random.rand!(mpsrand_rng(), A)
2015
end
2116
function Random.randn!(A::MPS.NormalArray)
22-
rng = can_use_mpsrandom(A) ? mpsrand_rng() : gpuarrays_rng()
23-
return Random.randn!(rng, A)
17+
return Random.randn!(mpsrand_rng(), A)
2418
end
2519

2620
# GPUArrays out-of-place
2721
function rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode)
28-
rng = prod(dims) * sizeof(T) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
29-
return Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
22+
return Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
3023
end
3124
function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
32-
rng = prod(dims) * sizeof(T) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
33-
return Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
25+
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
3426
end
3527
rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
3628
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
@@ -39,12 +31,10 @@ randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
3931

4032
# support all dimension specifications
4133
function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
42-
rng = (dim1 * prod(dims) * sizeof(T)) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
43-
return Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
34+
return Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
4435
end
4536
function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
46-
rng = (dim1 * prod(dims) * sizeof(T)) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
47-
return Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
37+
return Random.randn!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
4838
end
4939

5040
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
@@ -59,8 +49,8 @@ randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
5949
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
6050

6151
# scalars
62-
rand(T::Type=Float32; storage=SharedStorage) = rand(T, 4; storage)[1]
63-
randn(T::Type=Float32; storage=SharedStorage) = randn(T, 4; storage)[1]
52+
rand(T::Type=Float32; storage=SharedStorage) = rand(T, 1; storage)[1]
53+
randn(T::Type=Float32; storage=SharedStorage) = randn(T, 1; storage)[1]
6454

6555
# seeding
6656
function seed!(seed=Base.rand(UInt64))

test/random.jl

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,9 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
2424
# specified MPS rng
2525
if T != Float16
2626
fill!(A, T(0))
27-
if Metal.can_use_mpsrandom(A)
28-
f(rng, A)
29-
@test !iszero(collect(A))
30-
else
31-
@test_throws "Destination buffer" f(rng, A)
32-
end
27+
28+
f(rng, A)
29+
@test !iszero(collect(A))
3330
end
3431
end
3532

@@ -44,12 +41,9 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
4441
# specified MPS rng
4542
if T != Float16
4643
fill!(A, T(0))
47-
if Metal.can_use_mpsrandom(A)
48-
f(rng, A)
49-
@test Array(A) == fill(1, 0)
50-
else
51-
@test_throws "Destination buffer" f(rng, A)
52-
end
44+
45+
f(rng, A)
46+
@test Array(A) == fill(1, 0)
5347
end
5448
end
5549
end
@@ -131,30 +125,22 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
131125
idx = 4:50
132126
view_A = @view A[idx]
133127

134-
# Errors in Julia before crashing whole process
135-
if Metal.can_use_mpsrandom(view_A)
136-
f(rng, view_A)
128+
f(rng, view_A)
137129

138-
cpuA = collect(A)
139-
@test !iszero(cpuA[idx])
140-
@test iszero(cpuA[1:100 .∉ Ref(idx)]) broken=(sizeof(view_A) % 4 != 0)
141-
else
142-
@test_throws "Destination buffer" f(rng, view_A)
143-
end
130+
cpuA = collect(A)
131+
@test !iszero(cpuA[idx])
132+
@test iszero(cpuA[1:100 .∉ Ref(idx)])
144133

145134
## Offset == 0
146135
fill!(A, T(0))
147136
idx = 1:51
148137
view_A = @view A[idx]
149-
if Metal.can_use_mpsrandom(view_A)
150-
f(rng, view_A)
151-
152-
cpuA = collect(A)
153-
@test !iszero(cpuA[idx])
154-
@test iszero(cpuA[1:100 .∉ Ref(idx)])
155-
else
156-
@test_throws "Destination buffer" f(rng, view_A)
157-
end
138+
139+
f(rng, view_A)
140+
141+
cpuA = collect(A)
142+
@test !iszero(cpuA[idx])
143+
@test iszero(cpuA[1:100 .∉ Ref(idx)])
158144
end
159145
end
160146
end
@@ -200,12 +186,8 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
200186

201187
# specified MPS rng
202188
if T != Float16
203-
if length(zeros(args...)) * sizeof(T) % 4 == 0
204-
B = fr(rng, args...)
205-
@test eltype(B) == T
206-
else
207-
@test_throws "Destination buffer" fr(rng, args...)
208-
end
189+
B = fr(rng, args...)
190+
@test eltype(B) == T
209191
end
210192
end
211193

@@ -228,12 +210,9 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
228210
# d == 2 and d == 3 are to hit the test cases where sizeof(A) <= 4
229211
@testset "$d" for d in (2, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000), 16384, 16385)
230212
A = zeros(T, d)
231-
if (prod(d) * sizeof(T)) % 4 == 0
232-
f(rng, A)
233-
@test !iszero(A)
234-
else
235-
@test_throws "Destination buffer" f(rng, A)
236-
end
213+
214+
f(rng, A)
215+
@test !iszero(A)
237216
end
238217
end
239218
end

0 commit comments

Comments
 (0)