Skip to content

Commit d6fe40f

Browse files
authored
vulkan: Fix test-thread-safety crashes (ggml-org#17024)
The std::map pipeline_flash_attn_f32_f16 could be searched and inserted at the same time, which needs to hold the lock. To be safe, hold the lock for all of ggml_vk_load_shaders.
1 parent e14e842 commit d6fe40f

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ struct vk_pipeline_struct {
130130
// true if fields have been set by ggml_vk_create_pipeline
131131
bool initialized {};
132132
// set to true to request the pipeline is compiled
133-
bool needed {};
133+
std::atomic<bool> needed {};
134134
// set to true when the shader has been compiled
135-
bool compiled {};
135+
std::atomic<bool> compiled {};
136136
// number of registers used, extracted from pipeline executable properties
137137
uint32_t register_count {};
138138
};
@@ -1842,10 +1842,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
18421842
}
18431843
}
18441844

1845-
{
1846-
std::lock_guard<std::recursive_mutex> guard(device->mutex);
1847-
device->all_pipelines.push_back(pipeline);
1848-
}
1845+
device->all_pipelines.push_back(pipeline);
18491846

18501847
{
18511848
std::lock_guard<std::mutex> guard(compile_count_mutex);
@@ -2536,6 +2533,7 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
25362533
static void ggml_vk_load_shaders(vk_device& device) {
25372534
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
25382535

2536+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
25392537
// some shaders have a minimum subgroup size
25402538
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
25412539
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
@@ -2729,6 +2727,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27292727
if (!pipeline->needed || pipeline->compiled) {
27302728
return;
27312729
}
2730+
// TODO: We're no longer benefitting from the async compiles (shaders are
2731+
// compiled individually, as needed) and this complexity can be removed.
27322732
{
27332733
// wait until fewer than N compiles are in progress
27342734
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -7914,12 +7914,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
79147914

79157915
vk_pipeline pipeline = nullptr;
79167916

7917-
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
7918-
auto it = pipelines.find(fa_pipeline_state);
7919-
if (it != pipelines.end()) {
7920-
pipeline = it->second;
7921-
} else {
7922-
pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
7917+
{
7918+
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
7919+
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
7920+
auto it = pipelines.find(fa_pipeline_state);
7921+
if (it != pipelines.end()) {
7922+
pipeline = it->second;
7923+
} else {
7924+
pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
7925+
}
79237926
}
79247927

79257928
assert(pipeline);

0 commit comments

Comments
 (0)