@@ -149,6 +149,7 @@ class vk_perf_logger;
149
149
static void ggml_vk_destroy_buffer (vk_buffer& buf);
150
150
151
151
static constexpr uint32_t mul_mat_vec_max_cols = 8 ;
152
+ static constexpr uint32_t p021_max_gqa_ratio = 8 ;
152
153
153
154
enum vk_device_architecture {
154
155
OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
231
232
bool uma;
232
233
bool prefer_host_memory;
233
234
bool float_controls_rte_fp16;
235
+ bool subgroup_add;
234
236
235
237
bool subgroup_size_control;
236
238
uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
277
279
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
278
280
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
279
281
280
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
282
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio] ;
281
283
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
282
284
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
283
285
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
2265
2267
2266
2268
ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce" , split_k_reduce_len, split_k_reduce_data, " main" , 2 , 2 * sizeof (uint32_t ), {256 * 4 , 1 , 1 }, {}, 1 );
2267
2269
2268
- ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32" , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
2270
+ for (uint32_t i = 0 ; i < p021_max_gqa_ratio; ++i) {
2271
+ if (device->subgroup_add && device->subgroup_require_full_support ) {
2272
+ ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32" +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true , true );
2273
+ } else {
2274
+ ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 [i], " mul_mat_vec_p021_f16_f32" +std::to_string (i+1 ), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {device->subgroup_size , i + 1 }, 1 , true );
2275
+ }
2276
+ }
2269
2277
ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32" , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main" , 3 , 7 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
2270
2278
2271
2279
ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -2479,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
2479
2487
vk::PhysicalDeviceDriverProperties driver_props;
2480
2488
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2481
2489
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2490
+ vk::PhysicalDeviceVulkan11Properties vk11_props;
2482
2491
vk::PhysicalDeviceVulkan12Properties vk12_props;
2483
2492
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2484
2493
2485
2494
props2.pNext = &props3;
2486
2495
props3.pNext = &subgroup_props;
2487
2496
subgroup_props.pNext = &driver_props;
2488
- driver_props.pNext = &vk12_props;
2497
+ driver_props.pNext = &vk11_props;
2498
+ vk11_props.pNext = &vk12_props;
2489
2499
2490
2500
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
2491
2501
@@ -2549,6 +2559,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2549
2559
}
2550
2560
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16 ;
2551
2561
2562
+ device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
2563
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
2564
+
2552
2565
const bool force_disable_f16 = getenv (" GGML_VK_DISABLE_F16" ) != nullptr ;
2553
2566
2554
2567
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4635,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4635
4648
const uint64_t qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
4636
4649
const uint64_t d_sz = sizeof (float ) * d_ne;
4637
4650
4651
+ // With grouped query attention there are > 1 Q matrices per K, V matrix.
4652
+ uint32_t gqa_ratio = (uint32_t )ne12 / (uint32_t )ne02;
4653
+ if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
4654
+ gqa_ratio = 1 ;
4655
+ }
4656
+
4638
4657
if (dryrun) {
4639
4658
// Request descriptor sets
4640
- ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , 1 );
4659
+ ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio - 1 ] , 1 );
4641
4660
return ;
4642
4661
}
4643
4662
@@ -4661,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4661
4680
4662
4681
// compute
4663
4682
const std::array<uint32_t , 6 > pc = { (uint32_t )ne00, (uint32_t )ne01, (uint32_t )ne02, (uint32_t )ne12, (uint32_t )(qy_shader_offset / ggml_type_size (src1->type )), (uint32_t )(d_shader_offset / ggml_type_size (dst->type )) };
4683
+
4684
+ uint32_t workgroups_z = (uint32_t )ne12;
4685
+ // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
4686
+ if (gqa_ratio > 1 ) {
4687
+ workgroups_z /= gqa_ratio;
4688
+ }
4689
+
4664
4690
ggml_vk_sync_buffers (subctx);
4665
- ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, ( uint32_t )ne12 });
4691
+ ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 [gqa_ratio - 1 ] , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, workgroups_z });
4666
4692
}
4667
4693
4668
4694
static void ggml_vk_mul_mat_vec_nc_f16_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
0 commit comments