@@ -24,10 +24,9 @@ function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes
2424end
2525
2626function MPSNDArrayDescriptor (dataType:: DataType , shape:: DenseVector{T} ) where {T<: Union{Int,UInt} }
27- revshape = collect (reverse (shape))
28- obj = GC. @preserve revshape begin
29- shapeptr = pointer (revshape)
30- MPSNDArrayDescriptor (dataType, length (revshape), shapeptr)
27+ obj = GC. @preserve shape begin
28+ shapeptr = pointer (shape)
29+ MPSNDArrayDescriptor (dataType, length (shape), shapeptr)
3130 end
3231 return obj
3332end
7574 end
7675end
7776
77+ function Base. size (ndarr:: MPSNDArray )
78+ ndims = Int (ndarr. numberOfDimensions)
79+ Tuple ([Int (lengthOfDimension (ndarr,i)) for i in 0 : ndims- 1 ])
80+ end
81+
7882@objcwrapper immutable= false MPSTemporaryNDArray <: MPSNDArray
7983
8084@objcproperties MPSTemporaryNDArray begin
@@ -130,20 +134,23 @@ end
130134
131135function MPSNDArray (arr:: MtlArray{T,N} ) where {T,N}
132136 arrsize = size (arr)
133- @assert arrsize[end ]* sizeof (T) % 16 == 0 " Final dimension of arr must have a byte size divisible by 16"
137+ @assert arrsize[1 ]* sizeof (T) % 16 == 0 " First dimension of arr must have a byte size divisible by 16"
134138 desc = MPSNDArrayDescriptor (T, arrsize)
135139 return MPSNDArray (arr. data[], UInt (arr. offset), desc)
136140end
137141
138142function Metal. MtlArray (ndarr:: MPSNDArray ; storage = Metal. DefaultStorageMode, async = false )
139- ndims = Int (ndarr. numberOfDimensions)
140- arrsize = [lengthOfDimension (ndarr,i) for i in 0 : ndims- 1 ]
143+ arrsize = size (ndarr)
141144 T = convert (DataType, ndarr. dataType)
142- arr = MtlArray {T,ndims,storage} (undef, reverse (arrsize)... )
145+ arr = MtlArray {T,length(arrsize),storage} (undef, (arrsize). .. )
146+ return exportToMtlArray! (arr, ndarr; async)
147+ end
148+
149+ function exportToMtlArray! (arr:: MtlArray{T} , ndarr:: MPSNDArray ; async= false ) where T
143150 dev = device (arr)
144151
145152 cmdBuf = MTLCommandBuffer (global_queue (dev)) do cmdBuf
146- exportDataWithCommandBuffer (ndarr, cmdBuf, arr. data[], T, 0 , collect ( sizeof (T) .* reverse ( strides ( arr))) )
153+ exportDataWithCommandBuffer (ndarr, cmdBuf, arr. data[], T, arr. offset )
147154 end
148155
149156 async || wait_completed (cmdBuf)
@@ -157,6 +164,12 @@ exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffe
157164 destinationDataType: destinationDataType:: MPSDataType
158165 offset: offset:: NSUInteger
159166 rowStrides: pointer (rowStrides):: Ptr{NSInteger} ]:: Nothing
167+ exportDataWithCommandBuffer (ndarr:: MPSNDArray , cmdbuf:: MTLCommandBuffer , toBuffer, destinationDataType, offset) =
168+ @objc [ndarr:: MPSNDArray exportDataWithCommandBuffer: cmdbuf:: id{MTLCommandBuffer}
169+ toBuffer: toBuffer:: id{MTLBuffer}
170+ destinationDataType: destinationDataType:: MPSDataType
171+ offset: offset:: NSUInteger
172+ rowStrides: nil:: id{ObjectiveC.Object} ]:: Nothing
160173
161174# rowStrides in Bytes
162175importDataWithCommandBuffer! (ndarr:: MPSNDArray , cmdbuf:: MTLCommandBuffer , fromBuffer, sourceDataType, offset, rowStrides) =
@@ -165,6 +178,12 @@ importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBu
165178 sourceDataType: sourceDataType:: MPSDataType
166179 offset: offset:: NSUInteger
167180 rowStrides: pointer (rowStrides):: Ptr{NSInteger} ]:: Nothing
181+ importDataWithCommandBuffer! (ndarr:: MPSNDArray , cmdbuf:: MTLCommandBuffer , fromBuffer, sourceDataType, offset) =
182+ @objc [ndarr:: MPSNDArray importDataWithCommandBuffer: cmdbuf:: id{MTLCommandBuffer}
183+ fromBuffer: fromBuffer:: id{MTLBuffer}
184+ sourceDataType: sourceDataType:: MPSDataType
185+ offset: offset:: NSUInteger
186+ rowStrides: nil:: id{ObjectiveC.Object} ]:: Nothing
168187
169188# TODO
170189# exportDataWithCommandBuffer(toImages, offset)
0 commit comments