@@ -149,6 +149,7 @@ class vk_perf_logger;
149149static  void  ggml_vk_destroy_buffer (vk_buffer& buf);
150150
151151static  constexpr  uint32_t  mul_mat_vec_max_cols = 8 ;
152+ static  constexpr  uint32_t  p021_max_gqa_ratio = 8 ;
152153
153154enum  vk_device_architecture {
154155    OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
231232    bool  uma;
232233    bool  prefer_host_memory;
233234    bool  float_controls_rte_fp16;
235+     bool  subgroup_add;
234236
235237    bool  subgroup_size_control;
236238    uint32_t  subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
277279    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
278280    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
279281
280-     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
282+     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio] ;
281283    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
282284    vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
283285    vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
22652267
22662268    ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce"  , split_k_reduce_len, split_k_reduce_data, " main"  , 2 , 2  * sizeof (uint32_t ), {256  * 4 , 1 , 1 }, {}, 1 );
22672269
2268-     ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32"  , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main"  , 3 , 6  * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
2270+     for  (uint32_t  i = 0 ; i < p021_max_gqa_ratio; ++i) {
2271+         if  (device->subgroup_add  && device->subgroup_require_full_support ) {
2272+             ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32"  +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, " main"  , 3 , 6  * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true , true );
2273+         } else  {
2274+             ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32"  +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_len,              mul_mat_vec_p021_f16_f32_data,              " main"  , 3 , 6  * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true );
2275+         }
2276+     }
22692277    ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32"  , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main"  , 3 , 7  * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
22702278
22712279    ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32"  , norm_f32_len, norm_f32_data, " main"  , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -2479,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
24792487        vk::PhysicalDeviceDriverProperties driver_props;
24802488        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
24812489        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2490+         vk::PhysicalDeviceVulkan11Properties vk11_props;
24822491        vk::PhysicalDeviceVulkan12Properties vk12_props;
24832492        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
24842493
24852494        props2.pNext  = &props3;
24862495        props3.pNext  = &subgroup_props;
24872496        subgroup_props.pNext  = &driver_props;
2488-         driver_props.pNext  = &vk12_props;
2497+         driver_props.pNext  = &vk11_props;
2498+         vk11_props.pNext  = &vk12_props;
24892499
24902500        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
24912501
@@ -2549,6 +2559,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
25492559        }
25502560        device->float_controls_rte_fp16  = vk12_props.shaderRoundingModeRTEFloat16 ;
25512561
2562+         device->subgroup_add  = (vk11_props.subgroupSupportedStages  & vk::ShaderStageFlagBits::eCompute) &&
2563+                                (vk11_props.subgroupSupportedOperations  & vk::SubgroupFeatureFlagBits::eArithmetic);
2564+ 
25522565        const  bool  force_disable_f16 = getenv (" GGML_VK_DISABLE_F16"  ) != nullptr ;
25532566
25542567        device->fp16  = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4635,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
46354648    const  uint64_t  qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
46364649    const  uint64_t  d_sz = sizeof (float ) * d_ne;
46374650
4651+     //  With grouped query attention there are > 1 Q matrices per K, V matrix.
4652+     uint32_t  gqa_ratio = (uint32_t )ne12 / (uint32_t )ne02;
4653+     if  (gqa_ratio > 8  || gqa_ratio == 0  || ne12 != ne02 * gqa_ratio) {
4654+         gqa_ratio = 1 ;
4655+     }
4656+ 
46384657    if  (dryrun) {
46394658        //  Request descriptor sets
4640-         ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , 1 );
4659+         ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio -  1 ] , 1 );
46414660        return ;
46424661    }
46434662
@@ -4661,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
46614680
46624681    //  compute
46634682    const  std::array<uint32_t , 6 > pc = { (uint32_t )ne00, (uint32_t )ne01, (uint32_t )ne02, (uint32_t )ne12, (uint32_t )(qy_shader_offset / ggml_type_size (src1->type )), (uint32_t )(d_shader_offset / ggml_type_size (dst->type )) };
4683+ 
4684+     uint32_t  workgroups_z = (uint32_t )ne12;
4685+     //  When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
4686+     if  (gqa_ratio > 1 ) {
4687+         workgroups_z /= gqa_ratio;
4688+     }
4689+ 
46644690    ggml_vk_sync_buffers (subctx);
4665-     ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6  * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, ( uint32_t )ne12  });
4691+     ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio -  1 ] , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6  * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, workgroups_z  });
46664692}
46674693
46684694static  void  ggml_vk_mul_mat_vec_nc_f16_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const  ggml_tensor * src0, const  ggml_tensor * src1, ggml_tensor * dst, bool  dryrun = false ) {
0 commit comments