@@ -277,8 +277,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
277277};
278278
279279struct  ggml_backend_metal_context {
280-     struct  ggml_backend_metal_device_context ctx_dev;
281- 
282280    id <MTLCommandQueue > queue;
283281
284282    dispatch_queue_t  d_queue;
@@ -343,7 +341,7 @@ @implementation GGMLMetalClass
343341    return  data;
344342}
345343
346- static  struct  ggml_backend_metal_context * ggml_metal_init (void ) {
344+ static  struct  ggml_backend_metal_context * ggml_metal_init (ggml_backend_dev_t  dev ) {
347345    GGML_LOG_INFO (" %s : allocating\n "  , __func__);
348346
349347#if  TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -357,8 +355,9 @@ @implementation GGMLMetalClass
357355
358356    //  init context
359357    struct  ggml_backend_metal_context * ctx = calloc (1 , sizeof (struct  ggml_backend_metal_context));
358+     struct  ggml_backend_metal_device_context * ctx_dev = dev->context ;
360359
361-     id <MTLDevice > device = ggml_backend_metal_device_acq (&ctx-> ctx_dev );
360+     id <MTLDevice > device = ggml_backend_metal_device_acq (ctx_dev);
362361    GGML_LOG_INFO (" %s : picking default device: %s \n "  , __func__, [[device name ] UTF8String ]);
363362
364363    ctx->queue   = [device newCommandQueue ];
@@ -482,9 +481,9 @@ @implementation GGMLMetalClass
482481        }
483482    }
484483
485-     GGML_LOG_INFO (" %s : simdgroup reduction support   = %s \n "  , __func__, ctx-> ctx_dev . support_simdgroup_reduction  ? " true"   : " false"  );
486-     GGML_LOG_INFO (" %s : simdgroup matrix mul. support = %s \n "  , __func__, ctx-> ctx_dev . support_simdgroup_mm  ? " true"   : " false"  );
487-     GGML_LOG_INFO (" %s : hasUnifiedMemory              = %s \n "  , __func__, ctx-> ctx_dev . mtl_device .hasUnifiedMemory  ? " true"   : " false"  );
484+     GGML_LOG_INFO (" %s : simdgroup reduction support   = %s \n "  , __func__, ctx_dev-> support_simdgroup_reduction  ? " true"   : " false"  );
485+     GGML_LOG_INFO (" %s : simdgroup matrix mul. support = %s \n "  , __func__, ctx_dev-> support_simdgroup_mm  ? " true"   : " false"  );
486+     GGML_LOG_INFO (" %s : hasUnifiedMemory              = %s \n "  , __func__, ctx_dev-> mtl_device .hasUnifiedMemory  ? " true"   : " false"  );
488487
489488    ctx->capture_next_compute  = false ;
490489    ctx->capture_started  = false ;
@@ -536,8 +535,8 @@ @implementation GGMLMetalClass
536535            GGML_LOG_WARN (" %s : skipping %-40s  (not supported)\n "  , __func__, " kernel_"  #name); \
537536        }
538537
539-         const  bool  support_simdgroup_mm        = ctx-> ctx_dev . support_simdgroup_mm ;
540-         const  bool  support_simdgroup_reduction = ctx-> ctx_dev . support_simdgroup_reduction ;
538+         const  bool  support_simdgroup_mm        = ctx_dev-> support_simdgroup_mm ;
539+         const  bool  support_simdgroup_reduction = ctx_dev-> support_simdgroup_reduction ;
541540
542541        //  simd_sum and simd_max requires MTLGPUFamilyApple7
543542
@@ -740,7 +739,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
740739    }
741740
742741    [ctx->queue release ];
743-     ggml_backend_metal_device_rel (&ctx->ctx_dev );
744742
745743    dispatch_release (ctx->d_queue );
746744
@@ -798,15 +796,15 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
798796    return  nil ;
799797}
800798
801- static  bool  ggml_metal_supports_op (const  struct  ggml_backend_metal_context  * ctx , const  struct  ggml_tensor * op) {
799+ static  bool  ggml_metal_supports_op (const  struct  ggml_backend_metal_device_context  * ctx_dev , const  struct  ggml_tensor * op) {
802800    for  (size_t  i = 0 , n = 3 ; i < n; ++i) {
803801        if  (op->src [i] != NULL  && op->src [i]->type  == GGML_TYPE_BF16) {
804802            return  false ;
805803        }
806804    }
807805
808-     const  bool  support_simdgroup_mm        = ctx-> ctx_dev . support_simdgroup_mm ;
809-     const  bool  support_simdgroup_reduction = ctx-> ctx_dev . support_simdgroup_reduction ;
806+     const  bool  support_simdgroup_mm        = ctx_dev-> support_simdgroup_mm ;
807+     const  bool  support_simdgroup_reduction = ctx_dev-> support_simdgroup_reduction ;
810808
811809    switch  (op->op ) {
812810        case  GGML_OP_UNARY:
@@ -921,9 +919,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
921919}
922920
923921static  void  ggml_metal_encode_node (
924-      struct  ggml_backend_metal_context * ctx ,
922+                          ggml_backend_t    backend ,
925923                                   int    idx,
926924          id <MTLComputeCommandEncoder >   encoder) {
925+     struct  ggml_backend_metal_context        * ctx     = backend->context ;
926+     struct  ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
927+ 
927928    struct  ggml_cgraph * gf = ctx->gf ;
928929
929930    struct  ggml_tensor * node = ggml_graph_node (gf, idx);
@@ -953,7 +954,7 @@ static void ggml_metal_encode_node(
953954            } break ;
954955    }
955956
956-     if  (!ggml_metal_supports_op (ctx , dst)) {
957+     if  (!ggml_metal_supports_op (ctx_dev , dst)) {
957958        GGML_LOG_ERROR (" %s : error: unsupported op '%s '\n "  , __func__, ggml_op_desc (dst));
958959        GGML_ABORT (" unsupported op"  );
959960    }
@@ -1026,7 +1027,7 @@ static void ggml_metal_encode_node(
10261027    //             dst->name);
10271028    // }
10281029
1029-     id <MTLDevice > device = ctx-> ctx_dev . mtl_device ;
1030+     id <MTLDevice > device = ctx_dev-> mtl_device ;
10301031
10311032    switch  (dst->op ) {
10321033        case  GGML_OP_CONCAT:
@@ -3015,8 +3016,11 @@ static void ggml_metal_encode_node(
30153016}
30163017
30173018static  enum  ggml_status ggml_metal_graph_compute (
3018-         struct  ggml_backend_metal_context * ctx,
3019-                        struct  ggml_cgraph * gf) {
3019+             ggml_backend_t    backend,
3020+         struct  ggml_cgraph * gf) {
3021+     struct  ggml_backend_metal_context        * ctx     = backend->context ;
3022+     struct  ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
3023+ 
30203024    //  number of nodes encoded by the main thread (empirically determined)
30213025    const  int  n_main = 128 ;
30223026
@@ -3044,7 +3048,7 @@ static enum ggml_status ggml_metal_graph_compute(
30443048
30453049            if  (!ctx->capture_started ) {
30463050                //  create capture scope
3047-                 ctx->capture_scope  = [[MTLCaptureManager  sharedCaptureManager ] newCaptureScopeWithDevice: ctx->ctx_dev. mtl_device];
3051+                 ctx->capture_scope  = [[MTLCaptureManager  sharedCaptureManager ] newCaptureScopeWithDevice: ctx_dev-> mtl_device];
30483052
30493053                MTLCaptureDescriptor  * descriptor = [MTLCaptureDescriptor  new ];
30503054                descriptor.captureObject  = ctx->capture_scope ;
@@ -3087,7 +3091,7 @@ static enum ggml_status ggml_metal_graph_compute(
30873091                    [encoder pushDebugGroup: [NSString  stringWithCString: ggml_op_desc (ggml_graph_node (gf, idx)) encoding: NSUTF8StringEncoding]];
30883092                }
30893093
3090-                 ggml_metal_encode_node (ctx , idx, encoder);
3094+                 ggml_metal_encode_node (backend , idx, encoder);
30913095
30923096                if  (should_capture) {
30933097                    [encoder popDebugGroup ];
@@ -3462,6 +3466,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
34623466
34633467static  void  ggml_backend_metal_free (ggml_backend_t  backend) {
34643468    struct  ggml_backend_metal_context * ctx = (struct  ggml_backend_metal_context *)backend->context ;
3469+     struct  ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
3470+     ggml_backend_metal_device_rel (ctx_dev);
34653471    ggml_metal_free (ctx);
34663472    free (backend);
34673473}
@@ -3473,9 +3479,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
34733479}
34743480
34753481static  enum  ggml_status ggml_backend_metal_graph_compute (ggml_backend_t  backend, struct  ggml_cgraph * cgraph) {
3476-     struct  ggml_backend_metal_context * metal_ctx = (struct  ggml_backend_metal_context *)backend->context ;
3477- 
3478-     return  ggml_metal_graph_compute (metal_ctx, cgraph);
3482+     return  ggml_metal_graph_compute (backend, cgraph);
34793483}
34803484
34813485static  void  ggml_backend_metal_set_n_cb (ggml_backend_t  backend, int  n_cb) {
@@ -3522,8 +3526,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
35223526    return  &guid;
35233527}
35243528
3529+ //  TODO: remove in the future
35253530ggml_backend_t  ggml_backend_metal_init (void ) {
3526-     struct  ggml_backend_metal_context * ctx = ggml_metal_init ();
3531+     ggml_backend_dev_t  dev = ggml_backend_reg_dev_get (ggml_backend_metal_reg (), 0 );
3532+ 
3533+     struct  ggml_backend_metal_context * ctx = ggml_metal_init (dev);
35273534    if  (ctx == NULL ) {
35283535        GGML_LOG_ERROR (" %s : error: failed to allocate context\n "  , __func__);
35293536        return  NULL ;
@@ -3534,7 +3541,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
35343541    *backend = (struct  ggml_backend) {
35353542        /*  .guid      = */   ggml_backend_metal_guid (),
35363543        /*  .interface = */   ggml_backend_metal_i,
3537-         /*  .device    = */   &g_ggml_backend_metal_device ,
3544+         /*  .device    = */   dev ,
35383545        /*  .context   = */   ctx,
35393546    };
35403547
@@ -3559,9 +3566,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
35593566bool  ggml_backend_metal_supports_family (ggml_backend_t  backend, int  family) {
35603567    GGML_ASSERT (ggml_backend_is_metal (backend));
35613568
3562-     struct  ggml_backend_metal_context  * ctx  = ( struct  ggml_backend_metal_context *) backend->context ;
3569+     struct  ggml_backend_metal_device_context  * ctx_dev  = backend-> device ->context ;
35633570
3564-     return  [ctx->ctx_dev. mtl_device supportsFamily: (MTLGPUFamilyApple1  + family - 1 )];
3571+     return  [ctx_dev-> mtl_device supportsFamily: (MTLGPUFamilyApple1  + family - 1 )];
35653572}
35663573
35673574void  ggml_backend_metal_capture_next_compute (ggml_backend_t  backend) {
@@ -3623,9 +3630,25 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g
36233630}
36243631
36253632static  ggml_backend_t  ggml_backend_metal_device_init (ggml_backend_dev_t  dev, const  char  * params) {
3626-     return  ggml_backend_metal_init ();
3633+     struct  ggml_backend_metal_context * ctx = ggml_metal_init (dev);
3634+     if  (ctx == NULL ) {
3635+         GGML_LOG_ERROR (" %s : error: failed to allocate context\n "  , __func__);
3636+         return  NULL ;
3637+     }
3638+ 
3639+     ggml_backend_t  backend = malloc (sizeof (struct  ggml_backend));
3640+ 
3641+     *backend = (struct  ggml_backend) {
3642+         /*  .guid      = */   ggml_backend_metal_guid (),
3643+         /*  .interface = */   ggml_backend_metal_i,
3644+         /*  .device    = */   dev,
3645+         /*  .context   = */   ctx,
3646+     };
3647+ 
3648+     ggml_backend_metal_set_n_cb (backend, 1 );
3649+ 
3650+     return  backend;
36273651
3628-     GGML_UNUSED (dev);
36293652    GGML_UNUSED (params);
36303653}
36313654
@@ -3715,9 +3738,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
37153738}
37163739
37173740static  bool  ggml_backend_metal_device_supports_op (ggml_backend_dev_t  dev, const  struct  ggml_tensor * op) {
3718-     struct  ggml_backend_metal_context  * metal_ctx  = ( struct  ggml_backend_metal_context *) dev->context ;
3741+     struct  ggml_backend_metal_device_context  * ctx_dev  = dev->context ;
37193742
3720-     return  ggml_metal_supports_op (metal_ctx , op);
3743+     return  ggml_metal_supports_op (ctx_dev , op);
37213744}
37223745
37233746static  bool  ggml_backend_metal_device_supports_buft (ggml_backend_dev_t  dev, ggml_backend_buffer_type_t  buft) {
0 commit comments