@@ -1543,11 +1543,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
15431543 device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
15441544 }
15451545
1546+ vk::PhysicalDeviceProperties2 props2;
1547+ device->physical_device .getProperties2 (&props2);
1548+ std::string device_name = props2.properties .deviceName .data ();
15461549 std::vector<std::future<void >> compiles;
15471550 auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void * spv_data, const std::string &entrypoint,
15481551 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, const std::vector<uint32_t >& specialization_constants,
15491552 uint32_t align, bool disable_robustness = false , bool require_full_subgroups = false , uint32_t required_subgroup_size = 0 ) {
15501553
1554+ if (required_subgroup_size == 0 ) {
1555+ required_subgroup_size = (device_name.find (" RX 5700" ) != std::string::npos) ? 32 : required_subgroup_size;
1556+ }
1557+
15511558 if (!pipeline) {
15521559 pipeline = std::make_shared<vk_pipeline_struct>();
15531560 pipeline->name = name;
@@ -1573,6 +1580,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
15731580 parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
15741581 };
15751582
1583+ // New lambda for pipelines with subgroup size 64.
1584+ auto const &ggml_vk_create_pipeline_64 = [&](vk_device& device, vk_pipeline& pipeline,
1585+ const std::string &name, size_t spv_size, const void * spv_data,
1586+ const std::string &entrypoint, uint32_t parameter_count,
1587+ uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms,
1588+ const std::vector<uint32_t >& specialization_constants, uint32_t align,
1589+ bool disable_robustness = false , bool require_full_subgroups = false )
1590+ {
1591+ ggml_vk_create_pipeline (device, pipeline, name, spv_size, spv_data, entrypoint,
1592+ parameter_count, push_constant_size, wg_denoms,
1593+ specialization_constants, align, disable_robustness,
1594+ require_full_subgroups, 64 );
1595+ };
1596+
15761597#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
15771598 if (device->coopmat2 ) {
15781599
@@ -2151,11 +2172,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
21512172
21522173 ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
21532174
// im2col pipelines (f32 and f32 -> f16).
{
    // The f32 -> f16 variant uses the "_rte" shader build when the device reports
    // float_controls_rte_fp16, and the regular build otherwise. Pick it once here
    // instead of repeating the selection in every device-specific branch.
    const bool   use_rte_shader  = device->float_controls_rte_fp16;
    const size_t im2col_f16_len  = use_rte_shader ? im2col_f32_f16_rte_len : im2col_f32_f16_len;
    const void * im2col_f16_data = use_rte_shader ? static_cast<const void *>(im2col_f32_f16_rte_data)
                                                  : static_cast<const void *>(im2col_f32_f16_data);

    if (device_name.find("RX 5700") != std::string::npos) {
        // Workaround needed to speed up im2col on RX 5700: create these pipelines
        // through the wrapper that requests a subgroup size of 64.
        ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
        ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f16_len, im2col_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
    } else {
        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f16_len, im2col_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
    }
}
21602191
21612192 ggml_vk_create_pipeline (device, device->pipeline_timestep_embedding_f32 , " timestep_embedding_f32" , timestep_embedding_f32_len, timestep_embedding_f32_data, " main" , 2 , sizeof (vk_op_timestep_embedding_push_constants), {256 , 1 , 1 }, {}, 1 );
0 commit comments