4848 int mtl_device_ref_count;
4949 id <MTLLibrary > mtl_library;
5050
51+ NSLock * mtl_lock;
52+
5153 bool has_simdgroup_reduction;
5254 bool has_simdgroup_mm;
5355 bool has_residency_sets;
5456 bool has_bfloat;
5557 bool use_bfloat;
5658
59+ size_t max_size;
60+
5761 char name[128 ];
5862} g_ggml_ctx_dev_main = {
5963 /* .mtl_device =*/ nil ,
6064 /* .mtl_device_ref_count =*/ 0 ,
6165 /* .mtl_library =*/ nil ,
66+ /* .mtl_lock =*/ nil ,
6267 /* .has_simdgroup_reduction =*/ false ,
6368 /* .has_simdgroup_mm =*/ false ,
6469 /* .has_residency_sets =*/ false ,
6570 /* .has_bfloat =*/ false ,
6671 /* .use_bfloat =*/ false ,
72+ /* .max_size =*/ 0 ,
6773 /* .name =*/ " " ,
6874};
6975
7076// acquire
7177static id <MTLDevice > ggml_backend_metal_device_acq (struct ggml_backend_metal_device_context * ctx) {
7278 assert (ctx != NULL );
7379
80+ if (ctx->mtl_lock == nil ) {
81+ ctx->mtl_lock = [[NSLock alloc ] init ];
82+ }
83+
7484 if (ctx->mtl_device == nil ) {
7585 ctx->mtl_device = MTLCreateSystemDefaultDevice ();
7686 }
94104 ctx->use_bfloat = false ;
95105#endif
96106
107+ ctx->max_size = ctx->mtl_device .maxBufferLength ;
108+
97109 strncpy (ctx->name , [[ctx->mtl_device name ] UTF8String ], sizeof (ctx->name ) - 1 );
98110 }
99111
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
110122 ctx->mtl_device_ref_count --;
111123
112124 if (ctx->mtl_device_ref_count == 0 ) {
125+ if (ctx->mtl_lock ) {
126+ [ctx->mtl_lock release ];
127+ ctx->mtl_lock = nil ;
128+ }
129+
113130 if (ctx->mtl_library ) {
114131 [ctx->mtl_library release ];
115132 ctx->mtl_library = nil ;
@@ -977,7 +994,7 @@ @implementation GGMLMetalClass
977994 struct ggml_backend_metal_context * ctx = calloc (1 , sizeof (struct ggml_backend_metal_context));
978995 struct ggml_backend_metal_device_context * ctx_dev = dev->context ;
979996
980- id <MTLDevice > device = ggml_backend_metal_device_acq ( ctx_dev) ;
997+ id <MTLDevice > device = ctx_dev-> mtl_device ;
981998
982999 GGML_LOG_INFO (" %s : picking default device: %s \n " , __func__, [[device name ] UTF8String ]);
9831000
@@ -991,9 +1008,16 @@ @implementation GGMLMetalClass
9911008 ctx->d_queue = dispatch_queue_create (" ggml-metal" , DISPATCH_QUEUE_CONCURRENT);
9921009
9931010 // load library
994- if (ctx_dev->mtl_library == nil ) {
995- ctx_dev->mtl_library = ggml_metal_load_library (device, ctx_dev->use_bfloat );
1011+ {
1012+ [ctx_dev->mtl_lock lock ];
1013+
1014+ if (ctx_dev->mtl_library == nil ) {
1015+ ctx_dev->mtl_library = ggml_metal_load_library (device, ctx_dev->use_bfloat );
1016+ }
1017+
1018+ [ctx_dev->mtl_lock unlock ];
9961019 }
1020+
9971021 id <MTLLibrary > metal_library = ctx_dev->mtl_library ;
9981022 if (metal_library == nil ) {
9991023 GGML_LOG_ERROR (" %s : error: metal library is nil\n " , __func__);
@@ -5284,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
52845308 }
52855309
52865310 ggml_backend_metal_buffer_rset_free (ctx);
5287- ggml_backend_metal_device_rel (buffer->buft ->device ->context );
52885311
52895312 if (ctx->owned ) {
52905313#if TARGET_OS_OSX
@@ -5393,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
53935416 }
53945417
53955418 struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device ->context ;
5396- id <MTLDevice > device = ggml_backend_metal_device_acq (ctx_dev);
5419+
5420+ GGML_ASSERT (ctx_dev->mtl_device != nil );
5421+
5422+ id <MTLDevice > device = ctx_dev->mtl_device ;
53975423
53985424 ctx->all_data = ggml_metal_host_malloc (size_aligned);
53995425 ctx->all_size = size_aligned;
@@ -5416,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
54165442 if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers [0 ].metal == nil )) {
54175443 GGML_LOG_ERROR (" %s : error: failed to allocate buffer, size = %8.2f MiB\n " , __func__, size_aligned / 1024.0 / 1024.0 );
54185444 free (ctx);
5419- ggml_backend_metal_device_rel (ctx_dev);
54205445 return NULL ;
54215446 }
54225447
54235448 if (!ggml_backend_metal_buffer_rset_init (ctx, ctx_dev, device)) {
54245449 GGML_LOG_ERROR (" %s : error: failed to initialize residency set\n " , __func__);
54255450 free (ctx);
5426- ggml_backend_metal_device_rel (ctx_dev);
54275451 return NULL ;
54285452 }
54295453
@@ -5434,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
54345458
// Returns the alignment reported for buffers of this buffer type.
// The value is a fixed constant (32); `buft` is not consulted.
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return 32;

    // placed after the return on purpose: only marks the parameter as unused
    GGML_UNUSED(buft);
}
54395464
// Returns the maximum buffer size for this buffer type.
// The value is read from the `max_size` field of the device context attached
// to `buft->device`; this function does not acquire or release the device.
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    const struct ggml_backend_metal_device_context * ctx_dev =
        (const struct ggml_backend_metal_device_context *) buft->device->context;

    return ctx_dev->max_size;
}
54495470
54505471static bool ggml_backend_metal_buffer_type_is_host (ggml_backend_buffer_type_t buft) {
@@ -5517,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
55175538 }
55185539
55195540 struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5520- id <MTLDevice > device = ggml_backend_metal_device_acq (ctx_dev);
5541+
5542+ GGML_ASSERT (ctx_dev->mtl_device != nil );
5543+
5544+ id <MTLDevice > device = ctx_dev->mtl_device ;
55215545
55225546 // the buffer fits into the max buffer size allowed by the device
55235547 if (size_aligned <= device.maxBufferLength ) {
@@ -5573,7 +5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
55735597 if (!ggml_backend_metal_buffer_rset_init (ctx, ctx_dev, device)) {
55745598 GGML_LOG_ERROR (" %s : error: failed to initialize residency set\n " , __func__);
55755599 free (ctx);
5576- ggml_backend_metal_device_rel (ctx_dev);
55775600 return NULL ;
55785601 }
55795602
@@ -5589,10 +5612,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
55895612}
55905613
55915614static void ggml_backend_metal_free (ggml_backend_t backend) {
5592- struct ggml_backend_metal_context * ctx = backend->context ;
5593- struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
5615+ struct ggml_backend_metal_context * ctx = backend->context ;
55945616
5595- ggml_backend_metal_device_rel (ctx_dev);
55965617 ggml_metal_free (ctx);
55975618
55985619 free (backend);
@@ -5732,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
57325753
57335754 struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
57345755
5756+ GGML_ASSERT (ctx_dev->mtl_device != nil );
5757+
57355758 return [ctx_dev->mtl_device supportsFamily: (MTLGPUFamilyApple1 + family - 1 )];
57365759}
57375760
@@ -5751,23 +5774,18 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
57515774}
57525775
// Returns the device description string.
// This is the `name` array stored inside the device context, so the returned
// pointer refers to context-owned storage rather than a copy. The function
// does not populate the name itself.
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;

    return ctx_dev->name;
}
57615781
57625782static void ggml_backend_metal_device_get_memory (ggml_backend_dev_t dev, size_t * free, size_t * total) {
57635783 if (@available (macOS 10.12 , iOS 16.0 , *)) {
57645784 struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context ;
5765- id <MTLDevice > device = ggml_backend_metal_device_acq ( ctx_dev) ;
5785+ id <MTLDevice > device = ctx_dev-> mtl_device ;
57665786
57675787 *total = device.recommendedMaxWorkingSetSize ;
57685788 *free = *total - device.currentAllocatedSize ;
5769-
5770- ggml_backend_metal_device_rel (ctx_dev);
57715789 } else {
57725790 *free = 1 ;
57735791 *total = 1 ;
@@ -5845,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
58455863 }
58465864
58475865 struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context ;
5848- id <MTLDevice > device = ggml_backend_metal_device_acq (ctx_dev);
5866+
5867+ GGML_ASSERT (ctx_dev->mtl_device != nil );
5868+
5869+ id <MTLDevice > device = ctx_dev->mtl_device ;
58495870
58505871 // the buffer fits into the max buffer size allowed by the device
58515872 if (size_aligned <= device.maxBufferLength ) {
@@ -5901,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
59015922 if (!ggml_backend_metal_buffer_rset_init (ctx, ctx_dev, device)) {
59025923 GGML_LOG_ERROR (" %s : error: failed to initialize residency set\n " , __func__);
59035924 free (ctx);
5904- ggml_backend_metal_device_rel (ctx_dev);
59055925 return NULL ;
59065926 }
59075927
@@ -5915,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
59155935}
59165936
// Reports whether `buft` is one of the Metal buffer types.
// A buffer type is recognized by comparing its `get_name` callback against the
// two Metal `get_name` functions (regular and from-pointer). `dev` is not consulted.
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    return
        buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
        buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;

    // placed after the return on purpose: only marks the parameter as unused
    GGML_UNUSED(dev);
}
@@ -6001,8 +6022,19 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
60016022 /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
60026023};
60036024
// Process-exit hook: releases one reference on the global main device context.
// Takes no arguments and returns nothing so it can be passed to atexit().
static void ggml_metal_cleanup(void) {
    ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
}
6029+
6030+ // TODO: make thread-safe
60046031ggml_backend_reg_t ggml_backend_metal_reg (void ) {
6005- // TODO: make this thread-safe somehow?
6032+ ggml_backend_metal_device_acq (&g_ggml_ctx_dev_main);
6033+
6034+ // register cleanup callback
6035+ // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6036+ atexit (ggml_metal_cleanup);
6037+
60066038 {
60076039 g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
60086040 /* .api_version = */ GGML_BACKEND_API_VERSION,
0 commit comments