@@ -149,6 +149,7 @@ class vk_perf_logger;
149149static void ggml_vk_destroy_buffer (vk_buffer& buf);
150150
151151static constexpr uint32_t mul_mat_vec_max_cols = 8 ;
152+ static constexpr uint32_t p021_max_gqa_ratio = 8 ; // number of mul_mat_vec_p021 pipeline variants; sizes pipeline_mul_mat_vec_p021_f16_f32[] (variant i is created with the value i + 1)
152153
153154enum vk_device_architecture {
154155 OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
231232 bool uma;
232233 bool prefer_host_memory;
233234 bool float_controls_rte_fp16;
235+ bool subgroup_add;
234236
235237 bool subgroup_size_control;
236238 uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
277279 vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
278280 vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
279281
280- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
282+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio] ;
281283 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
282284 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
283285 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
22652267
22662268 ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce" , split_k_reduce_len, split_k_reduce_data, " main" , 2 , 2 * sizeof (uint32_t ), {256 * 4 , 1 , 1 }, {}, 1 );
22672269
2268- ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32" , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
2270+ for (uint32_t i = 0 ; i < p021_max_gqa_ratio; ++i) { // one pipeline per array slot; slot i is built with the value i + 1 and named with the suffix i + 1
2271+ if (device->subgroup_add && device->subgroup_require_full_support ) { // the subgroup_add shader data is used only when both device flags are set
2272+ ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32" +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true , true ); // NOTE(review): {subgroup_size, i + 1} are presumably specialization constants and the extra trailing `true` presumably requests full subgroups - confirm against ggml_vk_create_pipeline
2273+ } else { // otherwise fall back to the plain shader data, with the same name and the same {subgroup_size, i + 1} values
2274+ ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32" +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true );
2275+ }
2276+ }
22692277 ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32" , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main" , 3 , 7 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
22702278
22712279 ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -2471,13 +2479,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
24712479 vk::PhysicalDeviceDriverProperties driver_props;
24722480 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
24732481 vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2482+ vk::PhysicalDeviceVulkan11Properties vk11_props;
24742483 vk::PhysicalDeviceVulkan12Properties vk12_props;
24752484 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
24762485
24772486 props2.pNext = &props3;
24782487 props3.pNext = &subgroup_props;
24792488 subgroup_props.pNext = &driver_props;
2480- driver_props.pNext = &vk12_props;
2489+ driver_props.pNext = &vk11_props;
2490+ vk11_props.pNext = &vk12_props;
24812491
24822492 VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
24832493
@@ -2541,6 +2551,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
25412551 }
25422552 device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16 ;
25432553
2554+ device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && // true only if the compute stage is among the stages reported as supporting subgroup operations...
2555+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); // ...and the arithmetic category of subgroup operations is reported as supported
2556+
25442557 const bool force_disable_f16 = getenv (" GGML_VK_DISABLE_F16" ) != nullptr ;
25452558
25462559 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4627,9 +4640,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
46274640 const uint64_t qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
46284641 const uint64_t d_sz = sizeof (float ) * d_ne;
46294642
4643+ // With grouped query attention there are > 1 Q matrices per K, V matrix.
4644+ uint32_t gqa_ratio = (uint32_t )ne12 / (uint32_t )ne02;
4645+ if (gqa_ratio > p021_max_gqa_ratio || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { // gqa_ratio - 1 indexes pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio], so bound it by that constant rather than a literal; also reject ratios where ne12 is not an exact multiple of ne02
4646+ gqa_ratio = 1 ; // fall back to the ratio-1 pipeline (index 0)
4647+ }
4648+
46304649 if (dryrun) {
46314650 // Request descriptor sets
4632- ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , 1 );
4651+ ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio - 1 ] , 1 );
46334652 return ;
46344653 }
46354654
@@ -4653,8 +4672,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
46534672
46544673 // compute
46554674 const std::array<uint32_t , 6 > pc = { (uint32_t )ne00, (uint32_t )ne01, (uint32_t )ne02, (uint32_t )ne12, (uint32_t )(qy_shader_offset / ggml_type_size (src1->type )), (uint32_t )(d_shader_offset / ggml_type_size (dst->type )) };
4675+
4676+ uint32_t workgroups_z = (uint32_t )ne12; // default z dimension of the dispatch is ne12
4677+ // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
4678+ if (gqa_ratio > 1 ) {
4679+ workgroups_z /= gqa_ratio; // divides exactly: gqa_ratio is reset to 1 earlier in this function unless ne12 == ne02 * gqa_ratio
4680+ }
4681+
46564682 ggml_vk_sync_buffers (subctx);
4657- ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, ( uint32_t )ne12 });
4683+ ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio - 1 ] , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, workgroups_z });
46584684}
46594685
46604686static void ggml_vk_mul_mat_vec_nc_f16_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
0 commit comments