@@ -146,6 +146,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipelin
146146 id <MTLDevice > device;
147147
148148 ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
149+
150+ NSLock * lock;
149151};
150152
151153ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev) {
@@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
296298
297299 ggml_metal_library_t res = calloc (1 , sizeof (struct ggml_metal_library));
298300
299- res->obj = library;
300- res->device = device;
301+ res->obj = library;
302+ res->device = device;
301303 res->pipelines = ggml_metal_pipelines_init ();
304+ res->lock = [NSLock new ];
302305
303306 return res;
304307}
@@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
365368 res->obj = library;
366369 res->device = device;
367370 res->pipelines = ggml_metal_pipelines_init ();
371+ res->lock = [NSLock new ];
368372
369373 return res;
370374}
@@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
380384
381385 ggml_metal_pipelines_free (lib->pipelines );
382386
387+ [lib->lock release ];
388+
383389 free (lib);
384390}
385391
386392ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name) {
387- return ggml_metal_pipelines_get (lib->pipelines , name);
393+ [lib->lock lock ];
394+
395+ ggml_metal_pipeline_t res = ggml_metal_pipelines_get (lib->pipelines , name);
396+
397+ [lib->lock unlock ];
398+
399+ return res;
388400}
389401
390402ggml_metal_pipeline_t ggml_metal_library_compile_pipeline (ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
391- // note: the pipelines are cached in the library per device, so they are shared across all metal contexts
392- ggml_critical_section_start ();
403+ [lib->lock lock ];
393404
394- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
405+ ggml_metal_pipeline_t res = ggml_metal_pipelines_get (lib-> pipelines , name);
395406 if (res) {
396- ggml_critical_section_end () ;
407+ [lib->lock unlock ] ;
397408
398409 return res;
399410 }
@@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
414425 mtl_function = [lib->obj newFunctionWithName: base_func constantValues: cv->obj error: &error];
415426 }
416427 if (!mtl_function) {
417- ggml_critical_section_end () ;
428+ [lib->lock unlock ] ;
418429
419430 GGML_LOG_ERROR (" %s : failed to compile pipeline: base = '%s ', name = '%s '\n " , __func__, base, name);
420431 if (error) {
@@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
433444 (int ) res->obj .threadExecutionWidth );
434445
435446 if (res->obj .maxTotalThreadsPerThreadgroup == 0 || res->obj .threadExecutionWidth == 0 ) {
436- ggml_critical_section_end () ;
447+ [lib->lock unlock ] ;
437448
438449 GGML_LOG_ERROR (" %s : incompatible pipeline %s \n " , __func__, name);
439450
@@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
443454 ggml_metal_pipelines_add (lib->pipelines , name, res);
444455 }
445456
446- ggml_critical_section_end () ;
457+ [lib->lock unlock ] ;
447458
448459 return res;
449460}
0 commit comments