@@ -1423,6 +1423,36 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
14231423 return supported;
14241424}
14251425
// Per-GPU overrides for the required subgroup size of individual pipelines.
//   Outer key: substring of the device name that identifies the GPU.
//   Inner map: pipeline name -> required subgroup size.
// Declared const: the table is only ever read by get_subgroup_size().
static const std::unordered_map<std::string, std::unordered_map<std::string, uint32_t>> gpu_pipeline_config = {
    {" RX 5700", {
        {" im2col_f32", 64},
        {" im2col_f32_f16", 64}
    }}
};

// Looks up the subgroup size override configured for a pipeline on a given GPU.
//
// A GPU entry applies when its key occurs as a substring of device_name.
// Empty keys are ignored (they would match every device). If several keys
// match, the longest one wins, with ties broken by lexicographic order, so the
// result never depends on unordered_map iteration order.
//
// pipeline_name: name of the pipeline being created.
// device_name:   device name string of the GPU.
//
// Returns the configured subgroup size, or 0 when no override exists (no GPU
// key matches, the pipeline is not listed for that GPU, or the stored value is 0).
static uint32_t get_subgroup_size(const std::string &pipeline_name, const std::string &device_name) {
    // Pointer into the const table: no key copy and no second lookup needed.
    const decltype(gpu_pipeline_config)::value_type * best = nullptr;

    for (const auto & entry : gpu_pipeline_config) {
        const std::string & gpu_key = entry.first;
        if (gpu_key.empty() || device_name.find(gpu_key) == std::string::npos) {
            continue;
        }
        const bool more_specific =
            best == nullptr ||
            gpu_key.size() > best->first.size() ||
            (gpu_key.size() == best->first.size() && gpu_key < best->first);
        if (more_specific) {
            best = &entry;
        }
    }

    if (best == nullptr) {
        return 0;
    }

    const auto pipeline_it = best->second.find(pipeline_name);
    // A stored value of 0 is indistinguishable from "not defined" by design.
    return pipeline_it != best->second.end() ? pipeline_it->second : 0;
}
1455+
14261456static void ggml_vk_load_shaders (vk_device& device) {
14271457 VK_LOG_DEBUG (" ggml_vk_load_shaders(" << device->name << " )" );
14281458
@@ -1546,11 +1576,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
15461576 vk::PhysicalDeviceProperties2 props2;
15471577 device->physical_device .getProperties2 (&props2);
15481578 std::string device_name = props2.properties .deviceName .data ();
1579+
15491580 std::vector<std::future<void >> compiles;
15501581 auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void * spv_data, const std::string &entrypoint,
15511582 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, const std::vector<uint32_t >& specialization_constants,
15521583 uint32_t align, bool disable_robustness = false , bool require_full_subgroups = false , uint32_t required_subgroup_size = 0 ) {
15531584
1585+ required_subgroup_size = get_subgroup_size (name, device_name);
15541586 if (required_subgroup_size == 0 ) {
15551587 required_subgroup_size = (device_name.find (" RX 5700" ) != std::string::npos) ? 32 : required_subgroup_size;
15561588 }
@@ -1580,20 +1612,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
15801612 parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
15811613 };
15821614
1583- // New lambda for pipelines with subgroup size 64.
1584- auto const &ggml_vk_create_pipeline_64 = [&](vk_device& device, vk_pipeline& pipeline,
1585- const std::string &name, size_t spv_size, const void * spv_data,
1586- const std::string &entrypoint, uint32_t parameter_count,
1587- uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms,
1588- const std::vector<uint32_t >& specialization_constants, uint32_t align,
1589- bool disable_robustness = false , bool require_full_subgroups = false )
1590- {
1591- ggml_vk_create_pipeline (device, pipeline, name, spv_size, spv_data, entrypoint,
1592- parameter_count, push_constant_size, wg_denoms,
1593- specialization_constants, align, disable_robustness,
1594- require_full_subgroups, 64 );
1595- };
1596-
15971615#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
15981616 if (device->coopmat2 ) {
15991617
@@ -2174,11 +2192,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
21742192
21752193 // Workaround needed to speedup im2col on RX 5700
21762194 if (device_name.find (" RX 5700" ) != std::string::npos) {
2177- ggml_vk_create_pipeline_64 (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2195+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
21782196 if (device->float_controls_rte_fp16 ) {
2179- ggml_vk_create_pipeline_64 (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2197+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
21802198 } else {
2181- ggml_vk_create_pipeline_64 (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2199+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
21822200 }
21832201 } else {
21842202 ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
0 commit comments