};

struct ggml_metal {
-    id<MTLDevice>       device;
-    id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]
-
    ggml_metal_device_t  dev;
    ggml_metal_library_t lib;

@@ -91,15 +88,15 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
    // init context
    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));

-    res->device = ggml_metal_device_get_obj(dev);
+    id<MTLDevice> device = ggml_metal_device_get_obj(dev);

-    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]);
+    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

    // TODO: would it be better to have one queue for the backend and one queue for the device?
    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
    // res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
-    res->queue = ggml_metal_device_get_queue(dev);
-    if (res->queue == nil) {
+    id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);
+    if (queue == nil) {
        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
        return NULL;
    }
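A note on the pattern introduced here: `struct ggml_metal` no longer caches the Metal objects; every call site asks the `ggml_metal_device_t` handle for them. A minimal sketch of what such a handle and its accessors could look like follows; the wrapper struct and its field names are assumptions for illustration, not the actual ggml implementation.

#import <Metal/Metal.h>

// hypothetical wrapper type; the real ggml_metal_device_t likely carries more state
struct ggml_metal_device {
    id<MTLDevice>       mtl_device; // assumed field name
    id<MTLCommandQueue> mtl_queue;  // assumed field name, shared by all users of the device
};

typedef struct ggml_metal_device * ggml_metal_device_t;

// accessors matching the call sites in this diff
id<MTLDevice> ggml_metal_device_get_obj(ggml_metal_device_t dev) {
    return dev->mtl_device;
}

id<MTLCommandQueue> ggml_metal_device_get_queue(ggml_metal_device_t dev) {
    return dev->mtl_queue;
}

The upside is a single owner for the device-level objects; the TODO above keeps the door open for a separate per-backend queue later ([TAG_QUEUE_PER_BACKEND]).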
@@ -274,7 +271,8 @@ static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_te
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the source data into a Metal buffer
-        id<MTLBuffer> buf_src = [ctx->device newBufferWithBytes:data
+        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
+        id<MTLBuffer> buf_src = [device newBufferWithBytes:data
                                                     length:size
                                                    options:MTLResourceStorageModeShared];

@@ -289,7 +287,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
+        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:buf_src
@@ -315,7 +314,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,

void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
-        id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
+        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
+        id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
                                                           length:size
                                                          options:MTLResourceStorageModeShared
                                                      deallocator:nil];
@@ -331,7 +331,8 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
+        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
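Both async tensor functions follow the same blit-based recipe: wrap the host pointer in a shared MTLBuffer, then record a copy on a command buffer taken from the device queue. Below is a self-contained sketch of the upload direction, with the destination buffer and offset standing in for the tensor bookkeeping done by the real code.

#import <Metal/Metal.h>

// sketch: asynchronously upload `size` bytes from host memory into `buf_dst` at `offset`
static void copy_host_to_gpu_async(id<MTLDevice> device, id<MTLCommandQueue> queue,
                                   const void * data, id<MTLBuffer> buf_dst, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the source data into a shared Metal buffer (one extra copy on upload)
        id<MTLBuffer> buf_src = [device newBufferWithBytes:data
                                                    length:size
                                                   options:MTLResourceStorageModeShared];

        // record the copy; it runs after any work already queued on `queue`
        id<MTLCommandBuffer>      cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:buf_src
                   sourceOffset:0
                       toBuffer:buf_dst
              destinationOffset:offset
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit]; // async: do not wait here; synchronize later if needed
    }
}

The download direction is symmetric, except newBufferWithBytesNoCopy:length:options:deallocator: wraps the host destination without copying (the pointer must satisfy Metal's alignment requirements).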
@@ -362,6 +363,9 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

+    // keep the memory wired
+    ggml_metal_device_rsets_keep_alive(ctx->dev);
+
    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
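As a rough illustration of the scheme described in the comment above: one command buffer per encoder thread plus one for the caller, enqueued up-front so the GPU runs them in order. The node-encoding helper is a stub, and the real code additionally retains and reuses its command buffers with unretained references; this is only a sketch of the split.

#import <Metal/Metal.h>
#import <dispatch/dispatch.h>

// stub for illustration: the real code encodes the ggml ops of nodes [start, end) here
static void encode_nodes(id<MTLCommandBuffer> cmd_buf, int start, int end) {
    (void) cmd_buf; (void) start; (void) end;
}

static void submit_graph_sketch(id<MTLCommandQueue> queue, int n_cb, int n_nodes_0, int n_nodes) {
    NSMutableArray<id<MTLCommandBuffer>> * cmd_bufs = [NSMutableArray arrayWithCapacity:n_cb + 1];

    // enqueue all buffers first: the GPU executes them in enqueue order,
    // no matter which thread finishes encoding its share first
    for (int i = 0; i < n_cb + 1; ++i) {
        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
        [cmd_buf enqueue];
        [cmd_bufs addObject:cmd_buf];
    }

    // the calling thread encodes and commits the first n_nodes_0 nodes right away
    encode_nodes(cmd_bufs[0], 0, n_nodes_0);
    [cmd_bufs[0] commit];

    // n_cb helper threads encode the remaining nodes, split evenly
    const int n_rest = n_nodes - n_nodes_0;
    const int per_cb = (n_rest + n_cb - 1)/n_cb; // assumes n_cb > 0

    dispatch_apply(n_cb, dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0), ^(size_t i) {
        const int start = n_nodes_0 + (int) i*per_cb;
        const int end   = start + per_cb < n_nodes ? start + per_cb : n_nodes;

        encode_nodes(cmd_bufs[i + 1], start, end);
        [cmd_bufs[i + 1] commit];
    });
}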
@@ -389,7 +393,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *

    if (!ctx->capture_started) {
        // create capture scope
-        ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+        id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
+        ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];

        MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
        descriptor.captureObject = ctx->capture_scope;
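For completeness, programmatic GPU capture with a custom scope generally looks like the sketch below; the trace-to-file destination and the label are illustrative choices, not necessarily what the surrounding code configures.

#import <Metal/Metal.h>

// sketch: capture everything between beginScope/endScope into a .gputrace file
static id<MTLCaptureScope> begin_capture(id<MTLDevice> device, NSURL * trace_url) {
    MTLCaptureManager * manager = [MTLCaptureManager sharedCaptureManager];

    id<MTLCaptureScope> scope = [manager newCaptureScopeWithDevice:device];
    scope.label = @"ggml capture"; // arbitrary label

    MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
    descriptor.captureObject = scope;
    descriptor.destination   = MTLCaptureDestinationGPUTraceDocument;
    descriptor.outputURL     = trace_url;

    NSError * error = nil;
    if (![manager startCaptureWithDescriptor:descriptor error:&error]) {
        NSLog(@"failed to start capture: %@", error);
        return nil;
    }

    [scope beginScope];
    return scope; // caller closes the scope with endScope when done
}

The capture ends when the scope is closed with endScope and the manager's stopCapture is called.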
@@ -406,10 +411,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
        }
    }

+    // short-hand
+    id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
+
    // the main thread commits the first few commands immediately
    // cmd_buf[n_cb]
    {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
        [cmd_buf retain];

        if (ctx->cmd_bufs[n_cb].obj) {
@@ -428,7 +436,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
    // prepare the rest of the command buffers asynchronously (optional)
    // cmd_buf[0..n_cb)
    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
        [cmd_buf retain];

        if (ctx->cmd_bufs[cb_idx].obj) {
@@ -589,9 +597,11 @@ void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_c
}

bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
-    GGML_ASSERT(ctx->device != nil);
+    GGML_ASSERT(ctx->dev != nil);
+
+    id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);

-    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+    return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}

void ggml_metal_capture_next_compute(ggml_metal_t ctx) {