Skip to content

Commit 70ff50d

Browse files
committed
metal : avoid reference of device context in the backend context
ggml-ci
1 parent 34e0e6e commit 70ff50d

File tree

1 file changed

+54
-31
lines changed

1 file changed

+54
-31
lines changed

ggml/src/ggml-metal.m

Lines changed: 54 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -277,8 +277,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
277277
};
278278

279279
struct ggml_backend_metal_context {
280-
struct ggml_backend_metal_device_context ctx_dev;
281-
282280
id<MTLCommandQueue> queue;
283281

284282
dispatch_queue_t d_queue;
@@ -343,7 +341,7 @@ @implementation GGMLMetalClass
343341
return data;
344342
}
345343

346-
static struct ggml_backend_metal_context * ggml_metal_init(void) {
344+
static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
347345
GGML_LOG_INFO("%s: allocating\n", __func__);
348346

349347
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -357,8 +355,9 @@ @implementation GGMLMetalClass
357355

358356
// init context
359357
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
358+
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
360359

361-
id<MTLDevice> device = ggml_backend_metal_device_acq(&ctx->ctx_dev);
360+
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
362361
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
363362

364363
ctx->queue = [device newCommandQueue];
@@ -482,9 +481,9 @@ @implementation GGMLMetalClass
482481
}
483482
}
484483

485-
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_reduction ? "true" : "false");
486-
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->ctx_dev.support_simdgroup_mm ? "true" : "false");
487-
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->ctx_dev.mtl_device.hasUnifiedMemory ? "true" : "false");
484+
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
485+
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
486+
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
488487

489488
ctx->capture_next_compute = false;
490489
ctx->capture_started = false;
@@ -536,8 +535,8 @@ @implementation GGMLMetalClass
536535
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
537536
}
538537

539-
const bool support_simdgroup_mm = ctx->ctx_dev.support_simdgroup_mm;
540-
const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
538+
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
539+
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
541540

542541
// simd_sum and simd_max requires MTLGPUFamilyApple7
543542

@@ -740,7 +739,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
740739
}
741740

742741
[ctx->queue release];
743-
ggml_backend_metal_device_rel(&ctx->ctx_dev);
744742

745743
dispatch_release(ctx->d_queue);
746744

@@ -798,15 +796,15 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
798796
return nil;
799797
}
800798

801-
static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
799+
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
802800
for (size_t i = 0, n = 3; i < n; ++i) {
803801
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
804802
return false;
805803
}
806804
}
807805

808-
const bool support_simdgroup_mm = ctx->ctx_dev.support_simdgroup_mm;
809-
const bool support_simdgroup_reduction = ctx->ctx_dev.support_simdgroup_reduction;
806+
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
807+
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
810808

811809
switch (op->op) {
812810
case GGML_OP_UNARY:
@@ -921,9 +919,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
921919
}
922920

923921
static void ggml_metal_encode_node(
924-
struct ggml_backend_metal_context * ctx,
922+
ggml_backend_t backend,
925923
int idx,
926924
id<MTLComputeCommandEncoder> encoder) {
925+
struct ggml_backend_metal_context * ctx = backend->context;
926+
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
927+
927928
struct ggml_cgraph * gf = ctx->gf;
928929

929930
struct ggml_tensor * node = ggml_graph_node(gf, idx);
@@ -953,7 +954,7 @@ static void ggml_metal_encode_node(
953954
} break;
954955
}
955956

956-
if (!ggml_metal_supports_op(ctx, dst)) {
957+
if (!ggml_metal_supports_op(ctx_dev, dst)) {
957958
GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
958959
GGML_ABORT("unsupported op");
959960
}
@@ -1026,7 +1027,7 @@ static void ggml_metal_encode_node(
10261027
// dst->name);
10271028
//}
10281029

1029-
id<MTLDevice> device = ctx->ctx_dev.mtl_device;
1030+
id<MTLDevice> device = ctx_dev->mtl_device;
10301031

10311032
switch (dst->op) {
10321033
case GGML_OP_CONCAT:
@@ -3015,8 +3016,11 @@ static void ggml_metal_encode_node(
30153016
}
30163017

30173018
static enum ggml_status ggml_metal_graph_compute(
3018-
struct ggml_backend_metal_context * ctx,
3019-
struct ggml_cgraph * gf) {
3019+
ggml_backend_t backend,
3020+
struct ggml_cgraph * gf) {
3021+
struct ggml_backend_metal_context * ctx = backend->context;
3022+
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
3023+
30203024
// number of nodes encoded by the main thread (empirically determined)
30213025
const int n_main = 128;
30223026

@@ -3044,7 +3048,7 @@ static enum ggml_status ggml_metal_graph_compute(
30443048

30453049
if (!ctx->capture_started) {
30463050
// create capture scope
3047-
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->ctx_dev.mtl_device];
3051+
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
30483052

30493053
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
30503054
descriptor.captureObject = ctx->capture_scope;
@@ -3087,7 +3091,7 @@ static enum ggml_status ggml_metal_graph_compute(
30873091
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
30883092
}
30893093

3090-
ggml_metal_encode_node(ctx, idx, encoder);
3094+
ggml_metal_encode_node(backend, idx, encoder);
30913095

30923096
if (should_capture) {
30933097
[encoder popDebugGroup];
@@ -3462,6 +3466,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
34623466

34633467
static void ggml_backend_metal_free(ggml_backend_t backend) {
34643468
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
3469+
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
3470+
ggml_backend_metal_device_rel(ctx_dev);
34653471
ggml_metal_free(ctx);
34663472
free(backend);
34673473
}
@@ -3473,9 +3479,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
34733479
}
34743480

34753481
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
3476-
struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
3477-
3478-
return ggml_metal_graph_compute(metal_ctx, cgraph);
3482+
return ggml_metal_graph_compute(backend, cgraph);
34793483
}
34803484

34813485
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -3522,8 +3526,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
35223526
return &guid;
35233527
}
35243528

3529+
// TODO: remove in the future
35253530
ggml_backend_t ggml_backend_metal_init(void) {
3526-
struct ggml_backend_metal_context * ctx = ggml_metal_init();
3531+
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
3532+
3533+
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
35273534
if (ctx == NULL) {
35283535
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
35293536
return NULL;
@@ -3534,7 +3541,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
35343541
*backend = (struct ggml_backend) {
35353542
/* .guid = */ ggml_backend_metal_guid(),
35363543
/* .interface = */ ggml_backend_metal_i,
3537-
/* .device = */ &g_ggml_backend_metal_device,
3544+
/* .device = */ dev,
35383545
/* .context = */ ctx,
35393546
};
35403547

@@ -3559,9 +3566,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
35593566
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
35603567
GGML_ASSERT(ggml_backend_is_metal(backend));
35613568

3562-
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
3569+
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
35633570

3564-
return [ctx->ctx_dev.mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
3571+
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
35653572
}
35663573

35673574
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
@@ -3623,9 +3630,25 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g
36233630
}
36243631

36253632
static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
3626-
return ggml_backend_metal_init();
3633+
struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
3634+
if (ctx == NULL) {
3635+
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
3636+
return NULL;
3637+
}
3638+
3639+
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
3640+
3641+
*backend = (struct ggml_backend) {
3642+
/* .guid = */ ggml_backend_metal_guid(),
3643+
/* .interface = */ ggml_backend_metal_i,
3644+
/* .device = */ dev,
3645+
/* .context = */ ctx,
3646+
};
3647+
3648+
ggml_backend_metal_set_n_cb(backend, 1);
3649+
3650+
return backend;
36273651

3628-
GGML_UNUSED(dev);
36293652
GGML_UNUSED(params);
36303653
}
36313654

@@ -3715,9 +3738,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
37153738
}
37163739

37173740
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
3718-
struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)dev->context;
3741+
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
37193742

3720-
return ggml_metal_supports_op(metal_ctx, op);
3743+
return ggml_metal_supports_op(ctx_dev, op);
37213744
}
37223745

37233746
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {

0 commit comments

Comments
 (0)