Skip to content

Commit 33ef834

Browse files
committed
Fix and test MPSGraph random
1 parent 87ca2e4 commit 33ef834

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

lib/mpsgraphs/random.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# @objcwrapper immutable=false MPSGraphRandomOpDescriptor <: MPSGraphObject
22

3-
function MPSMatrixRandomOpDescriptor(distribution::MPSGraphRandomDistribution, dataType::MPSDataType)
4-
desc = @objc [MPSMatrixRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
3+
function MPSGraphRandomOpDescriptor(distribution::MPSGraphRandomDistribution, dataType)
4+
desc = @objc [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
55
dataType:dataType::MPSDataType]::id{MPSGraphRandomOpDescriptor}
66
obj = MPSGraphRandomOpDescriptor(desc)
77
return obj

test/mpsgraphs/random.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using BFloat16s
2+
3+
if MPS.is_supported(device())
4+
5+
using .MPSGraphs: MPSGraphRandomOpDescriptor, MPSGraphRandomDistributionNormal, MPSGraphRandomDistributionTruncatedNormal, MPSGraphRandomDistributionUniform
6+
@testset "MPSGraph random" begin
7+
# determined by looking at the error message when trying to construct
8+
# an invalid distribution/type combination
9+
for (dist, T) in [(MPSGraphRandomDistributionNormal, Float32),
10+
(MPSGraphRandomDistributionNormal, Float16),
11+
(MPSGraphRandomDistributionNormal, BFloat16),
12+
(MPSGraphRandomDistributionTruncatedNormal, Float32),
13+
(MPSGraphRandomDistributionTruncatedNormal, Float16),
14+
(MPSGraphRandomDistributionTruncatedNormal, BFloat16),
15+
(MPSGraphRandomDistributionUniform, Int64),
16+
(MPSGraphRandomDistributionUniform, Int32),
17+
(MPSGraphRandomDistributionUniform, Float32),
18+
(MPSGraphRandomDistributionUniform, Float16),
19+
(MPSGraphRandomDistributionUniform, BFloat16),
20+
]
21+
@test MPSGraphRandomOpDescriptor(MPSGraphRandomDistributionNormal, Float32) isa MPSGraphRandomOpDescriptor
22+
end
23+
end
24+
25+
end # MPS.is_supported(device())

0 commit comments

Comments
 (0)