Skip to content

Commit 3d94e96

Browse files
authored
metal : fix data race in pipeline library (#17731)
1 parent 7feb0a1 commit 3d94e96

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
5050
}
5151

5252
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
53-
if (ppls->data.find(name) == ppls->data.end()) {
53+
if (ppls->data.find(name) == ppls->data.end()) {
5454
return nullptr;
5555
}
5656

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipelin
146146
id<MTLDevice> device;
147147

148148
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
149+
150+
NSLock * lock;
149151
};
150152

151153
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
@@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
296298

297299
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
298300

299-
res->obj = library;
300-
res->device = device;
301+
res->obj = library;
302+
res->device = device;
301303
res->pipelines = ggml_metal_pipelines_init();
304+
res->lock = [NSLock new];
302305

303306
return res;
304307
}
@@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
365368
res->obj = library;
366369
res->device = device;
367370
res->pipelines = ggml_metal_pipelines_init();
371+
res->lock = [NSLock new];
368372

369373
return res;
370374
}
@@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
380384

381385
ggml_metal_pipelines_free(lib->pipelines);
382386

387+
[lib->lock release];
388+
383389
free(lib);
384390
}
385391

386392
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
387-
return ggml_metal_pipelines_get(lib->pipelines, name);
393+
[lib->lock lock];
394+
395+
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
396+
397+
[lib->lock unlock];
398+
399+
return res;
388400
}
389401

390402
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
391-
// note: the pipelines are cached in the library per device, so they are shared across all metal contexts
392-
ggml_critical_section_start();
403+
[lib->lock lock];
393404

394-
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
405+
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
395406
if (res) {
396-
ggml_critical_section_end();
407+
[lib->lock unlock];
397408

398409
return res;
399410
}
@@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
414425
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
415426
}
416427
if (!mtl_function) {
417-
ggml_critical_section_end();
428+
[lib->lock unlock];
418429

419430
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
420431
if (error) {
@@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
433444
(int) res->obj.threadExecutionWidth);
434445

435446
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
436-
ggml_critical_section_end();
447+
[lib->lock unlock];
437448

438449
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
439450

@@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
443454
ggml_metal_pipelines_add(lib->pipelines, name, res);
444455
}
445456

446-
ggml_critical_section_end();
457+
[lib->lock unlock];
447458

448459
return res;
449460
}

0 commit comments

Comments
 (0)