Skip to content

Commit ab8785b

Browse files
committed
Guard against invalid ranks
1 parent 70e0307 commit ab8785b

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

lib/mps/ndarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ export MPSNDArrayDescriptor
77
# @objcwrapper immutable=false MPSNDArrayDescriptor <: NSObject
88

99
function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes::Ptr)
10+
1 <= dimensionCount <= 16 || throw(ArgumentError("`dimensionCount` must be between 1 and 16 inclusive"))
11+
1012
desc = @objc [MPSNDArrayDescriptor descriptorWithDataType:dataType::MPSDataType
1113
dimensionCount:dimensionCount::NSUInteger
1214
dimensionSizes:dimensionSizes::Ptr{NSUInteger}]::id{MPSNDArrayDescriptor}

lib/mpsgraphs/tensor.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ function MPSGraphTensorData(matrix::MPSMatrix)
6363
return tensor
6464
end
6565

66-
# rank must be between 1 and 16 inclusive
6766
function MPSGraphTensorData(matrix::MPSMatrix, rank)
67+
1 <= rank <= 16 || throw(ArgumentError("`rank` must be between 1 and 16 inclusive"))
68+
6869
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
6970
tensor = MPSGraphTensorData(obj)
7071
finalizer(release, tensor)
@@ -81,8 +82,9 @@ function MPSGraphTensorData(vector::MPSVector)
8182
return tensor
8283
end
8384

84-
# rank must be between 1 and 16 inclusive
8585
function MPSGraphTensorData(vector::MPSVector, rank)
86+
1 <= rank <= 16 || throw(ArgumentError("`rank` must be between 1 and 16 inclusive"))
87+
8688
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
8789
tensor = MPSGraphTensorData(obj)
8890
finalizer(release, tensor)

test/mps/ndarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ end
2929
desc2.numberOfDimensions = 6
3030
@test desc2.numberOfDimensions == 6
3131

32+
@test_throws ArgumentError MPSNDArrayDescriptor(Float32, ones(Int, 17))
33+
3234
@static if Metal.macos_version() >= v"15"
3335
@test desc1.preferPackedRows == false
3436

0 commit comments

Comments
 (0)