@@ -130,9 +130,9 @@ struct vk_pipeline_struct {
130130 // true if fields have been set by ggml_vk_create_pipeline
131131 bool initialized {};
132132 // set to true to request the pipeline is compiled
133- bool needed {};
 133+ std::atomic<bool> needed {};
134134 // set to true when the shader has been compiled
135- bool compiled {};
 135+ std::atomic<bool> compiled {};
136136 // number of registers used, extracted from pipeline executable properties
137137 uint32_t register_count {};
138138};
@@ -1842,10 +1842,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
18421842 }
18431843 }
18441844
1845- {
1846- std::lock_guard<std::recursive_mutex> guard(device->mutex);
1847- device->all_pipelines.push_back(pipeline);
1848- }
1845+ device->all_pipelines.push_back(pipeline);
18491846
18501847 {
18511848 std::lock_guard<std::mutex> guard(compile_count_mutex);
@@ -2536,6 +2533,7 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
25362533static void ggml_vk_load_shaders(vk_device& device) {
25372534 VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
25382535
2536+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
25392537 // some shaders have a minimum subgroup size
25402538 const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
25412539 const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
@@ -2729,6 +2727,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27292727 if (!pipeline->needed || pipeline->compiled) {
27302728 return;
27312729 }
2730+ // TODO: We're no longer benefitting from the async compiles (shaders are
2731+ // compiled individually, as needed) and this complexity can be removed.
27322732 {
27332733 // wait until fewer than N compiles are in progress
27342734 uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -7914,12 +7914,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
79147914
79157915 vk_pipeline pipeline = nullptr;
79167916
7917- auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
7918- auto it = pipelines.find(fa_pipeline_state);
7919- if (it != pipelines.end()) {
7920- pipeline = it->second;
7921- } else {
7922- pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
7917+ {
7918+ std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
7919+ auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
7920+ auto it = pipelines.find(fa_pipeline_state);
7921+ if (it != pipelines.end()) {
7922+ pipeline = it->second;
7923+ } else {
7924+ pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
7925+ }
79237926 }
79247927
79257928 assert(pipeline);
0 commit comments