Skip to content

Commit d07ccb8

Browse files
committed
vulkan: multithread pipeline creation
1 parent ea40f60 commit d07ccb8

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

src/ggml-vulkan.cpp

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
14
#include "ggml-vulkan.h"
25
#include <vulkan/vulkan_core.h>
36
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
@@ -20,6 +23,8 @@
2023
#include <unordered_map>
2124
#include <memory>
2225
#include <mutex>
26+
#include <future>
27+
#include <thread>
2328

2429
#include "ggml.h"
2530
#include "ggml-backend-impl.h"
@@ -607,13 +612,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
607612

608613
GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
609614

610-
static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
615+
// variables to track number of compiles in progress
616+
static uint32_t compile_count = 0;
617+
std::mutex compile_count_mutex;
618+
std::condition_variable compile_count_cond;
619+
620+
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
611621
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
612622
GGML_ASSERT(parameter_count > 0);
613623
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
614624

615-
std::lock_guard<std::mutex> guard(device->mutex);
616-
617625
pipeline = std::make_shared<vk_pipeline_struct>();
618626
pipeline->name = name;
619627
pipeline->parameter_count = parameter_count;
@@ -681,7 +689,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
681689
pipeline->layout);
682690
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
683691

684-
device->pipelines.insert({ pipeline->name, pipeline });
692+
{
693+
std::lock_guard<std::mutex> guard(device->mutex);
694+
device->pipelines.insert({ pipeline->name, pipeline });
695+
}
696+
697+
{
698+
std::lock_guard<std::mutex> guard(compile_count_mutex);
699+
assert(compile_count > 0);
700+
compile_count--;
701+
}
702+
compile_count_cond.notify_all();
685703
}
686704

687705
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -1190,6 +1208,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
11901208
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
11911209
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
11921210

1211+
std::vector<std::future<void>> compiles;
1212+
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
1213+
{
1214+
// wait until fewer than N compiles are in progress
1215+
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
1216+
std::unique_lock<std::mutex> guard(compile_count_mutex);
1217+
while (compile_count >= N) {
1218+
compile_count_cond.wait(guard);
1219+
}
1220+
compile_count++;
1221+
}
1222+
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
1223+
};
1224+
11931225
if (device->fp16) {
11941226
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
11951227
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1739,6 +1771,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
17391771
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
17401772

17411773
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1774+
1775+
for (auto &c : compiles) {
1776+
c.wait();
1777+
}
17421778
}
17431779

17441780
static vk_device ggml_vk_get_device(size_t idx) {

0 commit comments

Comments
 (0)