@@ -85,6 +85,10 @@ struct vk_pipeline_struct {
8585 uint32_t parameter_count;
8686 std::array<uint32_t , 3 > wg_denoms;
8787 uint32_t align;
88+ // set to true to request the pipeline is compiled after the dryrun
89+ bool needed {};
90+ // set to true when the shader has been compiled
91+ bool compiled {};
8892};
8993
9094typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
@@ -186,16 +190,19 @@ struct vk_device_struct {
186190 bool mul_mat_id_m;
187191 bool mul_mat_id_s;
188192
189- vk_matmul_pipeline pipeline_matmul_f32;
190- vk_matmul_pipeline pipeline_matmul_f32_f16;
193+ // set to true to indicate that some shaders need to be compiled after the dryrun
194+ bool need_compiles {};
195+
196+ vk_matmul_pipeline pipeline_matmul_f32 {};
197+ vk_matmul_pipeline pipeline_matmul_f32_f16 {};
191198 vk_matmul_pipeline2 pipeline_matmul_f16;
192199 vk_matmul_pipeline2 pipeline_matmul_f16_f32;
193200 vk_pipeline pipeline_matmul_split_k_reduce;
194201
195202 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
196203 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
197204
198- vk_matmul_pipeline pipeline_matmul_id_f32;
205+ vk_matmul_pipeline pipeline_matmul_id_f32 {} ;
199206 vk_matmul_pipeline2 pipeline_matmul_id_f16;
200207 vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
201208
@@ -776,13 +783,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
776783 GGML_ASSERT (parameter_count > 0 );
777784 GGML_ASSERT (wg_denoms[0 ] > 0 && wg_denoms[1 ] > 0 && wg_denoms[2 ] > 0 ); // NOLINT
778785
779- pipeline = std::make_shared<vk_pipeline_struct>();
780- pipeline->name = name;
781- pipeline->parameter_count = parameter_count;
782- pipeline->push_constant_size = push_constant_size;
783- pipeline->wg_denoms = wg_denoms;
784- pipeline->align = align;
785-
786786 vk::ShaderModuleCreateInfo shader_module_create_info ({}, spv_size, reinterpret_cast <const uint32_t *>(spv_data));
787787 pipeline->shader_module = device->device .createShaderModule (shader_module_create_info);
788788
@@ -865,6 +865,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
865865 }
866866
867867 pipeline->pipeline = device->device .createComputePipeline (VK_NULL_HANDLE, compute_pipeline_create_info).value ;
868+ pipeline->compiled = true ;
868869
869870 {
870871 std::lock_guard<std::mutex> guard (device->mutex );
@@ -875,12 +876,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
875876 std::lock_guard<std::mutex> guard (compile_count_mutex);
876877 assert (compile_count > 0 );
877878 compile_count--;
878-
879- // "Progress bar" for shader compiles
880- static uint32_t total_compile_count = 0 ;
881- if ((total_compile_count++ % 10 ) == 0 ) {
882- std::cerr << " ." ;
883- }
884879 }
885880 compile_count_cond.notify_all ();
886881}
@@ -906,6 +901,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
// Records that `n` more descriptor sets are required for `pipeline` on `device`.
//
// The count is accumulated in device->pipeline_descriptor_set_requirements,
// keyed by the pipeline's name. If the pipeline has not been compiled yet,
// the request also marks it as needed and raises the device-wide
// need_compiles flag.
//
// device:   device whose per-pipeline requirement table is updated.
// pipeline: pipeline being requested; dereferenced unconditionally, so it
//           must be non-null.
// n:        number of additional descriptor sets to add to the running total.
906901static void ggml_pipeline_request_descriptor_sets (vk_device& device, vk_pipeline& pipeline, uint32_t n) {
907902 VK_LOG_DEBUG (" ggml_pipeline_request_descriptor_sets(" << pipeline->name << " , " << n << " )" );
 // Accumulate (not overwrite) the requirement for this pipeline name.
908903 device->pipeline_descriptor_set_requirements [pipeline->name ] += n;
 // A pipeline that is requested but not yet compiled is flagged here;
 // the flags are only set in this function, never cleared.
904+ if (!pipeline->compiled ) {
905+ pipeline->needed = true ;
906+ device->need_compiles = true ;
907+ }
909908}
910909
911910static void ggml_pipeline_allocate_descriptor_sets (vk_device& device) {
@@ -1388,8 +1387,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
13881387static void ggml_vk_load_shaders (vk_device& device) {
13891388 VK_LOG_DEBUG (" ggml_vk_load_shaders(" << device->name << " )" );
13901389
1391- std::cerr << " ggml_vulkan: Compiling shaders" ;
1392-
13931390 // some shaders have a minimum subgroup size
13941391 const uint32_t subgroup_size_16 = std::max (device->subgroup_size , 16u );
13951392 const uint32_t subgroup_size_32 = std::max (device->subgroup_size , 32u );
@@ -1527,15 +1524,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
15271524 }
15281525 }
15291526
1530- device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1531- device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1532-
1533- device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1527+ if (!device->pipeline_matmul_f32 ) {
1528+ device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1529+ }
1530+ if (!device->pipeline_matmul_f32_f16 ) {
1531+ device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1532+ }
1533+ if (!device->pipeline_matmul_id_f32 ) {
1534+ device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1535+ }
15341536
15351537 std::vector<std::future<void >> compiles;
15361538 auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void * spv_data, const std::string &entrypoint,
15371539 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, const std::vector<uint32_t >& specialization_constants,
15381540 uint32_t align, bool disable_robustness = false , bool require_full_subgroups = false , uint32_t required_subgroup_size = 0 ) {
1541+
1542+ if (!pipeline) {
1543+ pipeline = std::make_shared<vk_pipeline_struct>();
1544+ pipeline->name = name;
1545+ pipeline->parameter_count = parameter_count;
1546+ pipeline->push_constant_size = push_constant_size;
1547+ pipeline->wg_denoms = wg_denoms;
1548+ pipeline->align = align;
1549+ }
1550+
1551+ if (!pipeline->needed || pipeline->compiled ) {
1552+ return ;
1553+ }
15391554 {
15401555 // wait until fewer than N compiles are in progress
15411556 uint32_t N = std::max (1u , std::thread::hardware_concurrency ());
@@ -2050,7 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
20502065 for (auto &c : compiles) {
20512066 c.wait ();
20522067 }
2053- std::cerr << " Done! " << std::endl ;
2068+ device-> need_compiles = false ;
20542069}
20552070
20562071static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
@@ -7656,6 +7671,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
76567671 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
76577672 ggml_vk_build_graph (ctx, cgraph->nodes [i], i, nullptr , 0 , true , false , false );
76587673 }
7674+ if (ctx->device ->need_compiles ) {
7675+ ggml_vk_load_shaders (ctx->device );
7676+ }
76597677 ggml_vk_preallocate_buffers (ctx);
76607678 ggml_pipeline_allocate_descriptor_sets (ctx->device );
76617679
0 commit comments