@@ -100,16 +100,17 @@ end
100100
101101# # argument conversion
102102
103- struct  Adaptor
103+ struct  Adaptor{T  <:   Union{Nothing,MTLComputeCommandEncoder,MTL4ArgumentTable} } 
104104    #  the current command encoder, if any.
105-     cce:: Union{Nothing,MTLComputeCommandEncoder} 
105+     cce:: T 
106106end 
107107
108108#  convert Metal buffers to their GPU address
109+ function  Adapt. adapt_storage (to:: Adaptor{<:MTLComputeCommandEncoder} , buf:: MTLBuffer )
110+     MTL. use! (to. cce, buf, MTL. ReadWriteUsage)
111+     reinterpret (Core. LLVMPtr{Nothing,AS. Device}, buf. gpuAddress)
112+ end 
109113function  Adapt. adapt_storage (to:: Adaptor , buf:: MTLBuffer )
110-     if  to. cce != =  nothing 
111-         MTL. use! (to. cce, buf, MTL. ReadWriteUsage)
112-     end 
113114    reinterpret (Core. LLVMPtr{Nothing,AS. Device}, buf. gpuAddress)
114115end 
115116function  Adapt. adapt_storage (to:: Adaptor , ptr:: MtlPtr{T} ) where  {T}
264265end 
265266
266267@autoreleasepool  function  (kernel:: HostKernel )(args... ; groups= 1 , threads= 1 ,
267-                                                queue= global_queue (device ()))
268+                                                queue= use_metal4 () ?  global_queue4 (device ()) :  global_queue (device ()))
269+     use_mtl4 =  queue isa  MTL4CommandQueue  
270+ 
268271    groups =  MTLSize (groups)
269272    threads =  MTLSize (threads)
270273    (groups. width> 0  &&  groups. height> 0  &&  groups. depth> 0 ) || 
@@ -275,16 +278,38 @@ end
275278    (threads. width *  threads. height *  threads. depth) >  kernel. pipeline. maxTotalThreadsPerThreadgroup && 
276279        throw (ArgumentError (" Number of threads in group ($(threads. width *  threads. height *  threads. depth) ) should not exceed $(kernel. pipeline. maxTotalThreadsPerThreadgroup) " 
277280
278-     cmdbuf =  MTLCommandBuffer (queue)
279-     cmdbuf. label =  " MTLCommandBuffer($(nameof (kernel. f)) )" 
280-     cce =  MTLComputeCommandEncoder (cmdbuf)
281+     if  use_mtl4
282+         allocator =  MTL4CommandAllocator (device ())
283+         cmdbuf =  MTL4CommandBuffer (device ())
284+ 
285+         #  TODO : Initialize with descriptor to set label
286+         #  cmdbuf.label = "MTL4CommandBuffer($(nameof(kernel.f)))"
287+ 
288+         beginCommandBufferWithAllocator! (cmdbuf, allocator)
289+         cce =  MTL4ComputeCommandEncoder (cmdbuf)
290+     else 
291+         cmdbuf =  MTLCommandBuffer (queue)
292+         cmdbuf. label =  " MTLCommandBuffer($(nameof (kernel. f)) )" 
293+         cce =  MTLComputeCommandEncoder (cmdbuf)
294+     end 
295+ 
281296    argument_buffers =  try 
282297        MTL. set_function! (cce, kernel. pipeline)
283-         bufs =  encode_arguments! (cce, kernel, kernel. f, args... )
298+         if  use_mtl4
299+             argtabdesc =  MTL. MTL4ArgumentTableDescriptor ()
300+             argtabdesc. maxBufferBindCount =  min (31 , length (args) +  1 )
301+             argtab =  MTL. MTL4ArgumentTable (device (), argtabdesc)
302+             bufs =  encode_arguments! (argtab, kernel, kernel. f, args... )
303+ 
304+             MTL. set_argument_table! (cce, argtab)
305+         else 
306+             bufs =  encode_arguments! (cce, kernel, kernel. f, args... )
307+         end 
284308        MTL. append_current_function! (cce, groups, threads)
285309        bufs
286310    finally 
287311        close (cce)
312+         use_mtl4 &&  endCommandBuffer! (cmdbuf)
288313    end 
289314
290315    #  the command buffer retains resources that are explicitly encoded (i.e. direct buffer
@@ -295,20 +320,26 @@ end
295320    #  kernel has actually completed.
296321    # 
297322    #  TODO : is there a way to bind additional resources to the command buffer?
298-     roots =  [kernel. f, args]
299-     MTL. on_completed (cmdbuf) do  buf
300-         empty! (roots)
301-         foreach (free, argument_buffers)
302- 
303-         #  Check for errors
304-         #  XXX : we cannot do this nicely, e.g. throwing an `error` or reporting with `@error`
305-         #       because we're not allowed to switch tasks from this contexts.
306-         if  buf. status ==  MTL. MTLCommandBufferStatusError
307-             Core. println (" ERROR: Failed to submit command buffer: $(buf. error. localizedDescription) " 
323+     if  ! use_mtl4
324+         roots =  [kernel. f, args]
325+         MTL. on_completed (cmdbuf) do  buf
326+             empty! (roots)
327+             foreach (free, argument_buffers)
328+ 
329+             #  Check for errors
330+             #  XXX : we cannot do this nicely, e.g. throwing an `error` or reporting with `@error`
331+             #       because we're not allowed to switch tasks from this contexts.
332+             if  buf. status ==  MTL. MTLCommandBufferStatusError
333+                 Core. println (" ERROR: Failed to submit command buffer: $(buf. error. localizedDescription) " 
334+             end 
308335        end 
336+     end 
309337
338+     if  use_mtl4
339+         commit! (queue, cmdbuf)
340+     else 
341+         commit! (cmdbuf)
310342    end 
311-     commit! (cmdbuf)
312343end 
313344
314345# # Intra-warp Helpers
0 commit comments