@@ -149,6 +149,7 @@ class vk_perf_logger;
149149static void ggml_vk_destroy_buffer (vk_buffer& buf);
150150
151151static constexpr uint32_t mul_mat_vec_max_cols = 8 ;
152+ static constexpr uint32_t p021_max_gqa_ratio = 8 ; // number of p021 mat-vec pipeline variants kept per device; variant i is built with i + 1 as its second specialization constant
152153
153154enum vk_device_architecture {
154155 OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
231232 bool uma;
232233 bool prefer_host_memory;
233234 bool float_controls_rte_fp16;
235+ bool subgroup_add;
234236
235237 bool subgroup_size_control;
236238 uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
277279 vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
278280 vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
279281
280- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
282+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio] ;
281283 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
282284 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
283285 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
22652267
22662268 ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce" , split_k_reduce_len, split_k_reduce_data, " main" , 2 , 2 * sizeof (uint32_t ), {256 * 4 , 1 , 1 }, {}, 1 );
22672269
2268- ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32" , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
2270+ for (uint32_t i = 0 ; i < p021_max_gqa_ratio; ++i) { // fill every array slot; slot i is named with suffix i + 1 and gets i + 1 as its second specialization constant
2271+ if (device->subgroup_add && device->subgroup_require_full_support ) { // both device flags set: build from the *_subgroup_add shader data, passing one extra trailing 'true'
2272+ ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32" +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true , true );
2273+ } else { // otherwise build from the plain mul_mat_vec_p021_f16_f32 shader data
2274+ ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32" +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true );
2275+ }
2276+ }
22692277 ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32" , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main" , 3 , 7 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
22702278
22712279 ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -2479,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
24792487 vk::PhysicalDeviceDriverProperties driver_props;
24802488 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
24812489 vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2490+ vk::PhysicalDeviceVulkan11Properties vk11_props;
24822491 vk::PhysicalDeviceVulkan12Properties vk12_props;
24832492 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
24842493
24852494 props2.pNext = &props3;
24862495 props3.pNext = &subgroup_props;
24872496 subgroup_props.pNext = &driver_props;
2488- driver_props.pNext = &vk12_props;
2497+ driver_props.pNext = &vk11_props;
2498+ vk11_props.pNext = &vk12_props;
24892499
24902500 VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
24912501
@@ -2549,6 +2559,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
25492559 }
25502560 device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16 ;
25512561
2562+ device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
2563+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
2564+
25522565 const bool force_disable_f16 = getenv (" GGML_VK_DISABLE_F16" ) != nullptr ;
25532566
25542567 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4635,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
46354648 const uint64_t qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
46364649 const uint64_t d_sz = sizeof (float ) * d_ne;
46374650
4651+ // With grouped query attention there are > 1 Q matrices per K, V matrix.
4652+ uint32_t gqa_ratio = (uint32_t )ne12 / (uint32_t )ne02;
4653+ if (gqa_ratio > p021_max_gqa_ratio || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { // gqa_ratio - 1 indexes the p021 pipeline array, whose size is p021_max_gqa_ratio
4654+ gqa_ratio = 1 ; // out of range, or ne12 is not an exact multiple of ne02: use the ratio-1 pipeline
4655+ }
4656+
46384657 if (dryrun) {
46394658 // Request descriptor sets
4640- ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , 1 );
4659+ ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio - 1 ] , 1 );
46414660 return ;
46424661 }
46434662
@@ -4661,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
46614680
46624681 // compute
46634682 const std::array<uint32_t , 6 > pc = { (uint32_t )ne00, (uint32_t )ne01, (uint32_t )ne02, (uint32_t )ne12, (uint32_t )(qy_shader_offset / ggml_type_size (src1->type )), (uint32_t )(d_shader_offset / ggml_type_size (dst->type )) };
4683+
4684+ uint32_t workgroups_z = (uint32_t )ne12;
4685+ // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
4686+ if (gqa_ratio > 1 ) {
4687+ workgroups_z /= gqa_ratio;
4688+ }
4689+
46644690 ggml_vk_sync_buffers (subctx);
4665- ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, ( uint32_t )ne12 });
4691+ ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio - 1 ] , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, workgroups_z });
46664692}
46674693
46684694static void ggml_vk_mul_mat_vec_nc_f16_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
0 commit comments