Skip to content

Commit aad4eab

Browse files
committed
(Not working) Use Metal 4
1 parent c7ce089 commit aad4eab

File tree

1 file changed

+52
-21
lines changed

1 file changed

+52
-21
lines changed

src/compiler/execution.jl

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,17 @@ end
100100

101101
## argument conversion
102102

103-
struct Adaptor
103+
struct Adaptor{T <: Union{Nothing,MTLComputeCommandEncoder,MTL4ArgumentTable}}
104104
# the current command encoder, if any.
105-
cce::Union{Nothing,MTLComputeCommandEncoder}
105+
cce::T
106106
end
107107

108108
# convert Metal buffers to their GPU address
109+
function Adapt.adapt_storage(to::Adaptor{<:MTLComputeCommandEncoder}, buf::MTLBuffer)
110+
MTL.use!(to.cce, buf, MTL.ReadWriteUsage)
111+
reinterpret(Core.LLVMPtr{Nothing,AS.Device}, buf.gpuAddress)
112+
end
109113
function Adapt.adapt_storage(to::Adaptor, buf::MTLBuffer)
110-
if to.cce !== nothing
111-
MTL.use!(to.cce, buf, MTL.ReadWriteUsage)
112-
end
113114
reinterpret(Core.LLVMPtr{Nothing,AS.Device}, buf.gpuAddress)
114115
end
115116
function Adapt.adapt_storage(to::Adaptor, ptr::MtlPtr{T}) where {T}
@@ -264,7 +265,9 @@ end
264265
end
265266

266267
@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1,
267-
queue=global_queue(device()))
268+
queue=use_metal4() ? global_queue4(device()) : global_queue(device()))
269+
use_mtl4 = queue isa MTL4CommandQueue
270+
268271
groups = MTLSize(groups)
269272
threads = MTLSize(threads)
270273
(groups.width>0 && groups.height>0 && groups.depth>0) ||
@@ -275,16 +278,38 @@ end
275278
(threads.width * threads.height * threads.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup &&
276279
throw(ArgumentError("Number of threads in group ($(threads.width * threads.height * threads.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)"))
277280

278-
cmdbuf = MTLCommandBuffer(queue)
279-
cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))"
280-
cce = MTLComputeCommandEncoder(cmdbuf)
281+
if use_mtl4
282+
allocator = MTL4CommandAllocator(device())
283+
cmdbuf = MTL4CommandBuffer(device())
284+
285+
# TODO: Initialize with descriptor to set label
286+
# cmdbuf.label = "MTL4CommandBuffer($(nameof(kernel.f)))"
287+
288+
beginCommandBufferWithAllocator!(cmdbuf, allocator)
289+
cce = MTL4ComputeCommandEncoder(cmdbuf)
290+
else
291+
cmdbuf = MTLCommandBuffer(queue)
292+
cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))"
293+
cce = MTLComputeCommandEncoder(cmdbuf)
294+
end
295+
281296
argument_buffers = try
282297
MTL.set_function!(cce, kernel.pipeline)
283-
bufs = encode_arguments!(cce, kernel, kernel.f, args...)
298+
if use_mtl4
299+
argtabdesc = MTL.MTL4ArgumentTableDescriptor()
300+
argtabdesc.maxBufferBindCount = min(31, length(args) + 1)
301+
argtab = MTL.MTL4ArgumentTable(device(), argtabdesc)
302+
bufs = encode_arguments!(argtab, kernel, kernel.f, args...)
303+
304+
MTL.set_argument_table!(cce, argtab)
305+
else
306+
bufs = encode_arguments!(cce, kernel, kernel.f, args...)
307+
end
284308
MTL.append_current_function!(cce, groups, threads)
285309
bufs
286310
finally
287311
close(cce)
312+
use_mtl4 && endCommandBuffer!(cmdbuf)
288313
end
289314

290315
# the command buffer retains resources that are explicitly encoded (i.e. direct buffer
@@ -295,20 +320,26 @@ end
295320
# kernel has actually completed.
296321
#
297322
# TODO: is there a way to bind additional resources to the command buffer?
298-
roots = [kernel.f, args]
299-
MTL.on_completed(cmdbuf) do buf
300-
empty!(roots)
301-
foreach(free, argument_buffers)
302-
303-
# Check for errors
304-
# XXX: we cannot do this nicely, e.g. throwing an `error` or reporting with `@error`
305-
# because we're not allowed to switch tasks from this contexts.
306-
if buf.status == MTL.MTLCommandBufferStatusError
307-
Core.println("ERROR: Failed to submit command buffer: $(buf.error.localizedDescription)")
323+
if !use_mtl4
324+
roots = [kernel.f, args]
325+
MTL.on_completed(cmdbuf) do buf
326+
empty!(roots)
327+
foreach(free, argument_buffers)
328+
329+
# Check for errors
330+
# XXX: we cannot do this nicely, e.g. throwing an `error` or reporting with `@error`
331+
# because we're not allowed to switch tasks from this context.
332+
if buf.status == MTL.MTLCommandBufferStatusError
333+
Core.println("ERROR: Failed to submit command buffer: $(buf.error.localizedDescription)")
334+
end
308335
end
336+
end
309337

338+
if use_mtl4
339+
commit!(queue, cmdbuf)
340+
else
341+
commit!(cmdbuf)
310342
end
311-
commit!(cmdbuf)
312343
end
313344

314345
## Intra-warp Helpers

0 commit comments

Comments
 (0)