@@ -207,10 +207,9 @@ const _kernel_instances = Dict{UInt, Any}()
207207
208208# # kernel launching and argument encoding
209209
210- @inline @generated function encode_arguments! (cce, kernel, args... )
211- ex = quote
212- bufs = MTLBuffer[]
213- end
210+ @inline @generated function encode_arguments! (cce, kernel, args:: Vararg{Any,N} ) where {N}
211+ ex = quote end
212+ buffers = []
214213
215214 # the arguments passed into this function have not been `mtlconvert`ed, because we need
216215 # to retain the top-level MTLBuffer and MtlPtr objects. eager conversion of nested
@@ -230,17 +229,18 @@ const _kernel_instances = Dict{UInt, Any}()
230229 continue
231230 else
232231 # everything else is passed by reference, in an argument buffer
232+ buf = gensym (" buffer" )
233233 append! (ex. args, (quote
234- buf = encode_argument! (kernel, mtlconvert ($ (argex), cce))
235- set_buffer! (cce, buf, 0 , $ idx)
236- push! (bufs, buf)
234+ $ buf = encode_argument! (kernel, mtlconvert ($ (argex), cce))
235+ set_buffer! (cce, $ buf, 0 , $ idx)
237236 end ). args)
237+ push! (buffers, buf)
238238 end
239239 idx += 1
240240 end
241241
242242 append! (ex. args, (quote
243- return bufs
243+ return ( $ (buffers ... ),)
244244 end ). args)
245245
246246 ex
275275 (threads. width * threads. height * threads. depth) > kernel. pipeline. maxTotalThreadsPerThreadgroup &&
276276 throw (ArgumentError (" Number of threads in group ($(threads. width * threads. height * threads. depth) ) should not exceed $(kernel. pipeline. maxTotalThreadsPerThreadgroup) " ))
277277
278- kernel_state = KernelState (rand (UInt32))
278+ kernel_state = KernelState (Random . rand (UInt32))
279279
280280 cmdbuf = MTLCommandBuffer (queue)
281281 cmdbuf. label = " MTLCommandBuffer($(nameof (kernel. f)) )"
297297 # kernel has actually completed.
298298 #
299299 # TODO : is there a way to bind additional resources to the command buffer?
300- roots = [kernel. f, kernel_state, args]
300+ roots = [kernel. f, args]
301301 MTL. on_completed (cmdbuf) do buf
302302 empty! (roots)
303303 foreach (free, argument_buffers)
308308 if buf. status == MTL. MTLCommandBufferStatusError
309309 Core. println (" ERROR: Failed to submit command buffer: $(buf. error. localizedDescription) " )
310310 end
311-
312311 end
312+
313313 commit! (cmdbuf)
314314end
315315
0 commit comments