| 
 | 1 | +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.  | 
 | 2 | +// SPDX-License-Identifier: MIT  | 
 | 3 | + | 
1 | 4 | #include "ggml-vulkan.h"  | 
2 | 5 | #include <vulkan/vulkan_core.h>  | 
3 | 6 | #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)  | 
 | 
20 | 23 | #include <unordered_map>  | 
21 | 24 | #include <memory>  | 
22 | 25 | #include <mutex>  | 
 | 26 | +#include <future>  | 
 | 27 | +#include <thread>  | 
23 | 28 | 
 
  | 
24 | 29 | #include "ggml.h"  | 
25 | 30 | #include "ggml-backend-impl.h"  | 
@@ -607,13 +612,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx  | 
607 | 612 | 
 
  | 
608 | 613 | GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);  | 
609 | 614 | 
 
  | 
610 |  | -static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {  | 
 | 615 | +// variables to track number of compiles in progress  | 
 | 616 | +static uint32_t compile_count = 0;  | 
 | 617 | +std::mutex compile_count_mutex;  | 
 | 618 | +std::condition_variable compile_count_cond;  | 
 | 619 | + | 
 | 620 | +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {  | 
611 | 621 |     VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");  | 
612 | 622 |     GGML_ASSERT(parameter_count > 0);  | 
613 | 623 |     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT  | 
614 | 624 | 
 
  | 
615 |  | -    std::lock_guard<std::mutex> guard(device->mutex);  | 
616 |  | - | 
617 | 625 |     pipeline = std::make_shared<vk_pipeline_struct>();  | 
618 | 626 |     pipeline->name = name;  | 
619 | 627 |     pipeline->parameter_count = parameter_count;  | 
@@ -681,7 +689,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co  | 
681 | 689 |         pipeline->layout);  | 
682 | 690 |     pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;  | 
683 | 691 | 
 
  | 
684 |  | -    device->pipelines.insert({ pipeline->name, pipeline });  | 
 | 692 | +    {  | 
 | 693 | +        std::lock_guard<std::mutex> guard(device->mutex);  | 
 | 694 | +        device->pipelines.insert({ pipeline->name, pipeline });  | 
 | 695 | +    }  | 
 | 696 | + | 
 | 697 | +    {  | 
 | 698 | +        std::lock_guard<std::mutex> guard(compile_count_mutex);  | 
 | 699 | +        assert(compile_count > 0);  | 
 | 700 | +        compile_count--;  | 
 | 701 | +    }  | 
 | 702 | +    compile_count_cond.notify_all();  | 
685 | 703 | }  | 
686 | 704 | 
 
  | 
687 | 705 | static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {  | 
@@ -1190,6 +1208,20 @@ static void ggml_vk_load_shaders(vk_device& device) {  | 
1190 | 1208 |     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();  | 
1191 | 1209 |     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();  | 
1192 | 1210 | 
 
  | 
 | 1211 | +    std::vector<std::future<void>> compiles;  | 
 | 1212 | +    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {  | 
 | 1213 | +        {  | 
 | 1214 | +            // wait until fewer than N compiles are in progress  | 
 | 1215 | +            uint32_t N = std::max(1u, std::thread::hardware_concurrency());  | 
 | 1216 | +            std::unique_lock<std::mutex> guard(compile_count_mutex);  | 
 | 1217 | +            while (compile_count >= N) {  | 
 | 1218 | +                compile_count_cond.wait(guard);  | 
 | 1219 | +            }  | 
 | 1220 | +            compile_count++;  | 
 | 1221 | +        }  | 
 | 1222 | +        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));  | 
 | 1223 | +    };  | 
 | 1224 | + | 
1193 | 1225 |     if (device->fp16) {  | 
1194 | 1226 |         ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);  | 
1195 | 1227 |         ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);  | 
@@ -1739,6 +1771,10 @@ static void ggml_vk_load_shaders(vk_device& device) {  | 
1739 | 1771 |     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);  | 
1740 | 1772 | 
 
  | 
1741 | 1773 |     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);  | 
 | 1774 | + | 
 | 1775 | +    for (auto &c : compiles) {  | 
 | 1776 | +        c.wait();  | 
 | 1777 | +    }  | 
1742 | 1778 | }  | 
1743 | 1779 | 
 
  | 
1744 | 1780 | static vk_device ggml_vk_get_device(size_t idx) {  | 
 | 
0 commit comments