Skip to content

Commit 84447c4

Browse files
Fix MPSNDArrayDescriptor wrapper (#502)
Don't reverse dimensions automatically
1 parent b9610e3 commit 84447c4

File tree

4 files changed

+37
-14
lines changed

4 files changed

+37
-14
lines changed

lib/mps/MPS.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import GPUArrays
1818

1919
const MtlFloat = Union{Float32, Float16}
2020

21+
const MPSShape = NSArray#{NSNumber}
22+
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))
23+
2124
is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)
2225

2326
include("size.jl")

lib/mps/matrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ end
213213
"""
214214
matmul!(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
215215
transpose_left=false, transpose_right=false)
216-
A `MPSMatrixMultiplication` kernel thay computes:
216+
A `MPSMatrixMultiplication` kernel that computes:
217217
`c = alpha * op(a) * beta * op(b) + beta * C`
218218
219219
This function should not typically be used. Rather, use the normal `LinearAlgebra` interface

lib/mps/ndarray.jl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@ function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes
2424
end
2525

2626
function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}}
27-
revshape = collect(reverse(shape))
28-
obj = GC.@preserve revshape begin
29-
shapeptr = pointer(revshape)
30-
MPSNDArrayDescriptor(dataType, length(revshape), shapeptr)
27+
obj = GC.@preserve shape begin
28+
shapeptr = pointer(shape)
29+
MPSNDArrayDescriptor(dataType, length(shape), shapeptr)
3130
end
3231
return obj
3332
end
@@ -75,6 +74,11 @@ else
7574
end
7675
end
7776

77+
function Base.size(ndarr::MPSNDArray)
78+
ndims = Int(ndarr.numberOfDimensions)
79+
Tuple([Int(lengthOfDimension(ndarr,i)) for i in 0:ndims-1])
80+
end
81+
7882
@objcwrapper immutable=false MPSTemporaryNDArray <: MPSNDArray
7983

8084
@objcproperties MPSTemporaryNDArray begin
@@ -130,20 +134,23 @@ end
130134

131135
function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
132136
arrsize = size(arr)
133-
@assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16"
137+
@assert arrsize[1]*sizeof(T) % 16 == 0 "First dimension of arr must have a byte size divisible by 16"
134138
desc = MPSNDArrayDescriptor(T, arrsize)
135139
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
136140
end
137141

138142
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
139-
ndims = Int(ndarr.numberOfDimensions)
140-
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
143+
arrsize = size(ndarr)
141144
T = convert(DataType, ndarr.dataType)
142-
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
145+
arr = MtlArray{T,length(arrsize),storage}(undef, (arrsize)...)
146+
return exportToMtlArray!(arr, ndarr; async)
147+
end
148+
149+
function exportToMtlArray!(arr::MtlArray{T}, ndarr::MPSNDArray; async=false) where T
143150
dev = device(arr)
144151

145152
cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
146-
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
153+
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, arr.offset)
147154
end
148155

149156
async || wait_completed(cmdBuf)
@@ -157,6 +164,12 @@ exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffe
157164
destinationDataType:destinationDataType::MPSDataType
158165
offset:offset::NSUInteger
159166
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
167+
exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffer, destinationDataType, offset) =
168+
@objc [ndarr::MPSNDArray exportDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
169+
toBuffer:toBuffer::id{MTLBuffer}
170+
destinationDataType:destinationDataType::MPSDataType
171+
offset:offset::NSUInteger
172+
rowStrides:nil::id{ObjectiveC.Object}]::Nothing
160173

161174
# rowStrides in Bytes
162175
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset, rowStrides) =
@@ -165,6 +178,12 @@ importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBu
165178
sourceDataType:sourceDataType::MPSDataType
166179
offset:offset::NSUInteger
167180
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
181+
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset) =
182+
@objc [ndarr::MPSNDArray importDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
183+
fromBuffer:fromBuffer::id{MTLBuffer}
184+
sourceDataType:sourceDataType::MPSDataType
185+
offset:offset::NSUInteger
186+
rowStrides:nil::id{ObjectiveC.Object}]::Nothing
168187

169188
# TODO
170189
# exportDataWithCommandBuffer(toImages, offset)

test/mps/ndarray.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
77
T = Float32
88
DT = convert(MPSDataType, T)
99

10-
desc1 = MPSNDArrayDescriptor(T, 5,4,3,2,1)
10+
desc1 = MPSNDArrayDescriptor(T,1,2,3,4,5)
1111
@test desc1 isa MPSNDArrayDescriptor
1212
@test desc1.dataType == DT
1313
@test desc1.preferPackedRows == false
@@ -19,7 +19,7 @@ using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension
1919
@test lengthOfDimension(desc1,4) == 4
2020
@test lengthOfDimension(desc1,3) == 5
2121

22-
desc2 = MPSNDArrayDescriptor(T, (4,3,2,1))
22+
desc2 = MPSNDArrayDescriptor(T, (1,2,3,4))
2323
@test desc2 isa MPSNDArrayDescriptor
2424
@test desc2.dataType == DT
2525
@test desc2.numberOfDimensions == 4
@@ -51,6 +51,7 @@ using .MPS: MPSNDArray
5151
@test ndarr1.label == "Test1"
5252
@test ndarr1.numberOfDimensions == 5
5353
@test ndarr1.parent === nothing
54+
@test size(ndarr1) == (5,4,3,2,1)
5455

5556
ndarr2 = MPSNDArray(dev, 4)
5657
@test ndarr2 isa MPSNDArray
@@ -63,9 +64,9 @@ using .MPS: MPSNDArray
6364
@test ndarr2.parent === nothing
6465

6566
arr3 = MtlArray(ones(Float16, 2,3,4))
66-
@test_throws "Final dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)
67+
@test_throws "First dimension of arr must have a byte size divisible by 16" MPSNDArray(arr3)
6768

68-
arr4 = MtlArray(ones(Float16, 2,3,8))
69+
arr4 = MtlArray(ones(Float16, 8,3,2))
6970

7071
@static if Metal.macos_version() >= v"15"
7172
@test ndarr1.descriptor isa MPSNDArrayDescriptor

0 commit comments

Comments
 (0)