Skip to content

Commit 5056e33

Browse files
Initial support for MPSNDArray (#499)
1 parent 8c119cf commit 5056e33

File tree

4 files changed

+426
-1
lines changed

4 files changed

+426
-1
lines changed

lib/mps/MPS.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include("images.jl")
2929
include("matrix.jl")
3030
include("vector.jl")
3131
include("matrixrandom.jl")
32+
include("ndarray.jl")
3233
include("decomposition.jl")
3334
include("copy.jl")
3435

lib/mps/kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ end
1515
@objcwrapper MPSKernel <: NSObject
1616

1717
@objcproperties MPSKernel begin
18+
@autoproperty options::MPSKernelOptions setter=setOptions
1819
@autoproperty device::id{MTLDevice}
1920
@autoproperty label::id{NSString} setter=setLabel
20-
@autoproperty options::MPSKernelOptions setter=setOptions
2121
end
2222

2323
@autoreleasepool function Base.copy(kernel::K) where {K <: MPSKernel}

lib/mps/ndarray.jl

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
#
2+
# ndarray descriptor
3+
#
4+
5+
export MPSNDArrayDescriptor
6+
7+
@objcwrapper immutable=false MPSNDArrayDescriptor <: NSObject
8+
9+
@objcproperties MPSNDArrayDescriptor begin
10+
@autoproperty dataType::MPSDataType setter=setDataType
11+
@autoproperty numberOfDimensions::NSUInteger setter=setNumberOfDimensions
12+
13+
# Both are officially available starting macOS 15, but they work in macOS 13/14
14+
@autoproperty preferPackedRows::Bool setter=setPreferPackedRows # macOS 15+
15+
@autoproperty getShape::id{NSArray} # macOS 15+
16+
end
17+
18+
function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes::Ptr)
19+
desc = @objc [MPSNDArrayDescriptor descriptorWithDataType:dataType::MPSDataType
20+
dimensionCount:dimensionCount::NSUInteger
21+
dimensionSizes:dimensionSizes::Ptr{NSUInteger}]::id{MPSNDArrayDescriptor}
22+
obj = MPSNDArrayDescriptor(desc)
23+
return obj
24+
end
25+
26+
function MPSNDArrayDescriptor(dataType::DataType, shape::DenseVector{T}) where {T<:Union{Int,UInt}}
27+
revshape = collect(reverse(shape))
28+
obj = GC.@preserve revshape begin
29+
shapeptr = pointer(revshape)
30+
MPSNDArrayDescriptor(dataType, length(revshape), shapeptr)
31+
end
32+
return obj
33+
end
34+
MPSNDArrayDescriptor(dataType::DataType, shape::Tuple) = MPSNDArrayDescriptor(dataType, collect(shape))
35+
36+
MPSNDArrayDescriptor(dataType::DataType, dimensionSizes...) = @inline MPSNDArrayDescriptor(dataType, collect(dimensionSizes))
37+
38+
lengthOfDimension(desc::MPSNDArrayDescriptor, dim) = @objc [desc::id{MPSNDArrayDescriptor} lengthOfDimension:dim::UInt]::UInt
39+
40+
function transposeDimensionwithDimension(desc::MPSNDArrayDescriptor, dim1, dim2)
41+
@objc [desc::id{MPSNDArrayDescriptor} transposeDimension:dim1::UInt
42+
withDimension:dim2::UInt]::Cvoid
43+
end
44+
45+
#
46+
# ndarray object
47+
#
48+
49+
export MPSNDArray
50+
51+
@objcwrapper immutable=false MPSNDArray <: NSObject
52+
53+
@static if Metal.macos_version() >= v"15"
54+
@objcproperties MPSNDArray begin
55+
@autoproperty dataType::MPSDataType
56+
@autoproperty dataTypeSize::Csize_t
57+
@autoproperty device::id{MTLDevice}
58+
@autoproperty label::id{NSString} setter=setLabel
59+
@autoproperty numberOfDimensions::NSUInteger
60+
@autoproperty parent::id{MPSNDArray}
61+
62+
#Instance methods that act like properties
63+
@autoproperty descriptor::id{MPSNDArrayDescriptor}
64+
@autoproperty resourceSize::NSUInteger
65+
@autoproperty userBuffer::id{MTLBuffer}
66+
end
67+
else
68+
@objcproperties MPSNDArray begin
69+
@autoproperty dataType::MPSDataType
70+
@autoproperty dataTypeSize::Csize_t
71+
@autoproperty device::id{MTLDevice}
72+
@autoproperty label::id{NSString} setter=setLabel
73+
@autoproperty numberOfDimensions::NSUInteger
74+
@autoproperty parent::id{MPSNDArray}
75+
end
76+
end
77+
78+
@objcwrapper immutable=false MPSTemporaryNDArray <: MPSNDArray
79+
80+
@objcproperties MPSTemporaryNDArray begin
81+
@autoproperty readCount::NSUInteger setter=setReadCount
82+
end
83+
84+
function MPSTemporaryNDArray(cmdbuf::MTLCommandBuffer, descriptor::MPSNDArrayDescriptor)
85+
@objc [MPSNDTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
86+
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSTemporaryNDArray}
87+
return obj
88+
end
89+
90+
"""
91+
MPSNDArray([device::MTLDevice], arr::MtlArray)
92+
93+
Metal ndarray representation used in Performance Shaders.
94+
95+
May not contain more than 16 dimensions.
96+
"""
97+
function MPSNDArray(device::MTLDevice, desc::MPSNDArrayDescriptor)
98+
arrayaddr = @objc [MPSNDArray alloc]::id{MPSNDArray}
99+
obj = MPSNDArray(arrayaddr)
100+
finalizer(release, obj)
101+
@objc [obj::MPSNDArray initWithDevice:device::id{MTLDevice}
102+
descriptor:desc::id{MPSNDArrayDescriptor}]::id{MPSNDArray}
103+
return obj
104+
end
105+
106+
function MPSNDArray(device::MTLDevice, scalar)
107+
arrayaddr = @objc [MPSNDArray alloc]::id{MPSNDArray}
108+
obj = MPSNDArray(arrayaddr)
109+
finalizer(release, obj)
110+
@objc [obj::MPSNDArray initWithDevice:device::id{MTLDevice}
111+
scalar:scalar::Float64]::id{MPSNDArray}
112+
return obj
113+
end
114+
115+
@static if Metal.macos_version() >= v"15"
116+
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
117+
arrayaddr = @objc [MPSNDArray alloc]::id{MPSNDArray}
118+
obj = MPSNDArray(arrayaddr)
119+
finalizer(release, obj)
120+
@objc [obj::MPSNDArray initWithBuffer:buffer::id{MTLBuffer}
121+
offset:offset::NSUInteger
122+
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSNDArray}
123+
return obj
124+
end
125+
else
126+
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
127+
@assert false "Creating an MPSNDArray that shares data with user-provided MTLBuffer is only supported in macOS v15+"
128+
end
129+
end
130+
131+
function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
132+
arrsize = size(arr)
133+
@assert arrsize[end]*sizeof(T) % 16 == 0 "Final dimension of arr must have a byte size divisible by 16"
134+
desc = MPSNDArrayDescriptor(T, arrsize)
135+
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
136+
end
137+
138+
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode)
139+
ndims = Int(ndarr.numberOfDimensions)
140+
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
141+
T = convert(DataType, ndarr.dataType)
142+
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
143+
dev = device(arr)
144+
145+
cmdBuf = MTLCommandBuffer(global_queue(dev))
146+
147+
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
148+
149+
commit!(cmdBuf)
150+
wait_completed(cmdBuf)
151+
152+
return arr
153+
end
154+
155+
# rowStrides in Bytes
156+
exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffer, destinationDataType, offset, rowStrides) =
157+
GC.@preserve rowStrides @objc [ndarr::MPSNDArray exportDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
158+
toBuffer:toBuffer::id{MTLBuffer}
159+
destinationDataType:destinationDataType::MPSDataType
160+
offset:offset::NSUInteger
161+
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
162+
163+
# rowStrides in Bytes
164+
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset, rowStrides) =
165+
GC.@preserve rowStrides @objc [ndarr::MPSNDArray importDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
166+
fromBuffer:fromBuffer::id{MTLBuffer}
167+
sourceDataType:sourceDataType::MPSDataType
168+
offset:offset::NSUInteger
169+
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
170+
171+
# TODO
172+
# exportDataWithCommandBuffer(toImages, offset)
173+
# importDataWithCommandBuffer(fromImages, offset)
174+
175+
# 0-indexed
176+
lengthOfDimension(ndarr::MPSNDArray, dimensionIndex) =
177+
@objc [ndarr::MPSNDArray lengthOfDimension:dimensionIndex::NSUInteger]::UInt
178+
179+
# TODO
180+
# readBytes(strideBytes)
181+
# writeBytes(strideBytes)
182+
183+
synchronizeOnCommandBuffer(ndarr::MPSNDArray, q::MTLCommandBuffer) =
184+
@objc [ndarr::MPSNDArray synchronizeOnCommandBuffer:q::id{MTLCommandBuffer}]::Nothing
185+
186+
187+
export MPSNDArrayMultiaryBase
188+
189+
@objcwrapper immutable=false MPSNDArrayMultiaryBase <: MPSKernel
190+
191+
export MPSNDArrayMultiaryKernel
192+
193+
@objcwrapper immutable=false MPSNDArrayMultiaryKernel <: MPSNDArrayMultiaryBase
194+
195+
function MPSNDArrayMultiaryKernel(device, sourceCount)
196+
kernel = @objc [MPSNDArrayMultiaryKernel alloc]::id{MPSNDArrayMultiaryKernel}
197+
obj = MPSNDArrayMultiaryKernel(kernel)
198+
finalizer(release, obj)
199+
@objc [obj::id{MPSNDArrayMultiaryKernel} initWithDevice:device::id{MTLDevice}
200+
sourceCount:sourceCount::NSUInteger]::id{MPSNDArrayMultiaryKernel}
201+
return obj
202+
end
203+
204+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays) where {K<:MPSNDArrayMultiaryKernel}
205+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
206+
sourceArrays:sourceArrays::id{NSArray}]::id{MPSNDArray}
207+
end
208+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays, destinationArray) where {K<:MPSNDArrayMultiaryKernel}
209+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
210+
sourceArrays:sourceArrays::id{NSArray}
211+
destinationArray:destinationArray::id{MPSNDArray}]::Nothing
212+
end
213+
# TODO: MPSState is not implemented yet, so these don't work
214+
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays, resultState, destinationArray) where {K<:MPSNDArrayMultiaryKernel}
215+
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
216+
# sourceArrays:sourceArrays::id{NSArray}
217+
# resultState:resultState::id{MPSState}
218+
# destinationArray:destinationArray::id{MPSNDArray}]::Nothing
219+
# end
220+
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArrays, resultState, outputStateIsTemporary::Bool) where {K<:MPSNDArrayMultiaryKernel}
221+
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
222+
# sourceArrays:sourceArrays::id{NSArray}
223+
# resultState:resultState::id{MPSState}
224+
# outputStateIsTemporary:outputStateIsTemporary::Bool]::MPSNDArray
225+
# end
226+
227+
export MPSNDArrayUnaryKernel
228+
229+
@objcwrapper immutable=false MPSNDArrayUnaryKernel <: MPSNDArrayMultiaryBase
230+
231+
function MPSNDArrayUnaryKernel(device)
232+
kernel = @objc [MPSNDArrayUnaryKernel alloc]::id{MPSNDArrayUnaryKernel}
233+
obj = MPSNDArrayUnaryKernel(kernel)
234+
finalizer(release, obj)
235+
@objc [obj::id{MPSNDArrayUnaryKernel} initWithDevice:device::id{MTLDevice}]::id{MPSNDArrayUnaryKernel}
236+
return obj
237+
end
238+
239+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray) where {K<:MPSNDArrayUnaryKernel}
240+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
241+
sourceArray:sourceArray::id{MPSNDArray}]::id{MPSNDArray}
242+
end
243+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray, destinationArray) where {K<:MPSNDArrayUnaryKernel}
244+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
245+
sourceArray:sourceArray::id{MPSNDArray}
246+
destinationArray:destinationArray::id{MPSNDArray}]::Nothing
247+
end
248+
# TODO: MPSState is not implemented yet, so these don't work
249+
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray, resultState, destinationArray) where {K<:MPSNDArrayUnaryKernel}
250+
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
251+
# sourceArray:sourceArray::id{MPSNDArray}
252+
# resultState:resultState::id{MPSState}
253+
# destinationArray:destinationArray::id{MPSNDArray}]::Nothing
254+
# end
255+
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, sourceArray, resultState, outputStateIsTemporary::Bool) where {K<:MPSNDArrayUnaryKernel}
256+
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
257+
# sourceArray:sourceArrays::id{MPSNDArray}
258+
# resultState:resultState::id{MPSState}
259+
# outputStateIsTemporary:outputStateIsTemporary::Bool]::MPSNDArray
260+
# end
261+
262+
export MPSNDArrayBinaryKernel
263+
264+
@objcwrapper immutable=false MPSNDArrayBinaryKernel <: MPSNDArrayMultiaryBase
265+
266+
function MPSNDArrayBinaryKernel(device)
267+
kernel = @objc [MPSNDArrayBinaryKernel alloc]::id{MPSNDArrayBinaryKernel}
268+
obj = MPSNDArrayBinaryKernel(kernel)
269+
finalizer(release, obj)
270+
@objc [obj::id{MPSNDArrayBinaryKernel} initWithDevice:device::id{MTLDevice}]::id{MPSNDArrayBinaryKernel}
271+
return obj
272+
end
273+
274+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray) where {K<:MPSNDArrayBinaryKernel}
275+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
276+
secondarySourceArray:secondarySourceArray::id{MPSNDArray}
277+
primarySourceArray:primarySourceArray::id{MPSNDArray}]::id{MPSNDArray}
278+
end
279+
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray, destinationArray) where {K<:MPSNDArrayBinaryKernel}
280+
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
281+
primarySourceArray:primarySourceArray::id{MPSNDArray}
282+
secondarySourceArray:secondarySourceArray::id{MPSNDArray}
283+
destinationArray:destinationArray::id{MPSNDArray}]::Nothing
284+
end
285+
# TODO: MPSState is not implemented yet, so these don't work
286+
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray, resultState, destinationArray) where {K<:MPSNDArrayBinaryKernel}
287+
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
288+
# primarySourceArray:primarySourceArray::id{MPSNDArray}
289+
# secondarySourceArray:secondarySourceArray::id{MPSNDArray}
290+
# resultState:resultState::id{MPSState}
291+
# destinationArray:destinationArray::id{MPSNDArray}]::Nothing
292+
# end
293+
# function encode!(cmdbuf::MTLCommandBuffer, kernel::K, primarySourceArray, secondarySourceArray, resultState, outputStateIsTemporary::Bool) where {K<:MPSNDArrayBinaryKernel}
294+
# @objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
295+
# primarySourceArray:primarySourceArrays::id{MPSNDArray}
296+
# secondarySourceArray:secondarySourceArray::id{MPSNDArray}
297+
# resultState:resultState::id{MPSState}
298+
# outputStateIsTemporary:outputStateIsTemporary::Bool]::MPSNDArray
299+
# end
300+
301+
@objcwrapper immutable=false MPSNDArrayMatrixMultiplication <: MPSNDArrayMultiaryKernel
302+
303+
@objcproperties MPSNDArrayMatrixMultiplication begin
304+
@autoproperty alpha::Float64 setter=setAlpha
305+
@autoproperty beta::Float64 setter=setBeta
306+
end
307+
308+
function MPSNDArrayMatrixMultiplication(device, sourceCount)
309+
kernel = @objc [MPSNDArrayMatrixMultiplication alloc]::id{MPSNDArrayMatrixMultiplication}
310+
obj = MPSNDArrayMatrixMultiplication(kernel)
311+
finalizer(release, obj)
312+
@objc [obj::id{MPSNDArrayMatrixMultiplication} initWithDevice:device::id{MTLDevice}
313+
sourceCount:sourceCount::NSUInteger]::id{MPSNDArrayMatrixMultiplication}
314+
return obj
315+
end

0 commit comments

Comments
 (0)