
Commit fd8be28

vulkan: Add Integer Dot Product mul_mat_vec shader for legacy quants
1 parent ca0ef2d commit fd8be28

4 files changed: +310 additions, -14 deletions


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 73 additions & 13 deletions
@@ -344,6 +344,7 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
     bool subgroup_shuffle;
+    bool subgroup_clustered;
 
     bool integer_dot_product;
 
@@ -409,6 +410,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
+    vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
@@ -2752,6 +2755,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product) {
+            if (device->subgroup_clustered && device->vendor_id != VK_VENDOR_ID_INTEL) {
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_q8_1_f32_subgroup_len, mul_mat_vec_q4_0_q8_1_f32_subgroup_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_q8_1_f32_subgroup_len, mul_mat_vec_q4_1_q8_1_f32_subgroup_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_q8_1_f32_subgroup_len, mul_mat_vec_q5_0_q8_1_f32_subgroup_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_q8_1_f32_subgroup_len, mul_mat_vec_q5_1_q8_1_f32_subgroup_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_q8_1_f32_subgroup_len, mul_mat_vec_q8_0_q8_1_f32_subgroup_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true, true);
+            } else {
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_q8_1_f32_len, mul_mat_vec_q4_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_q8_1_f32_len, mul_mat_vec_q4_1_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_q8_1_f32_len, mul_mat_vec_q5_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_q8_1_f32_len, mul_mat_vec_q5_1_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_q8_1_f32_len, mul_mat_vec_q8_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
+            }
+        }
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
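The new *_q8_1_f32 pipelines registered above skip float dequantization: the weight matrix stays in its legacy quantized blocks and the activation vector is requantized to Q8_1, so the inner loop can run on 8-bit integer dot products. The _subgroup_ variants are only chosen when clustered subgroup operations are available (and not on Intel); they reduce partial sums within the subgroup instead of through shared memory. As a rough illustration of the arithmetic these shaders accumulate, here is a minimal CPU-side sketch of one q4_0 x q8_1 block dot product (not code from this commit; it assumes the standard ggml block layouts and uses float scales where ggml stores fp16):

// Illustrative only: one 32-element block of the q4_0 x q8_1 dot product that the
// mul_mat_vec_*_q8_1_f32 shaders evaluate with integer dot-product instructions.
#include <cstdint>

struct block_q4_0 { float d; uint8_t qs[16]; };   // 32 weights: two 4-bit values per byte, stored with a +8 offset
struct block_q8_1 { float d, s; int8_t qs[32]; }; // 32 activations: s = d * sum(qs), precomputed at quantization time

static float vec_dot_q4_0_q8_1_block(const block_q4_0 &x, const block_q8_1 &y) {
    int32_t sumi = 0;                              // pure integer accumulation, the part the GPU does with packed dots
    for (int j = 0; j < 16; ++j) {
        const int v0 = x.qs[j] & 0x0F;             // low nibble  -> weight j
        const int v1 = x.qs[j] >> 4;               // high nibble -> weight j + 16
        sumi += v0 * y.qs[j] + v1 * y.qs[j + 16];
    }
    // sum_k d_x*(q_k - 8) * d_y*q8_k, with the -8 offset folded into the precomputed y.s:
    return x.d * (y.d * (float) sumi - 8.0f * y.s);
}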
@@ -3275,9 +3296,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
-
     device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+    device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
 
     const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
@@ -4236,9 +4258,22 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
 
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
     VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
-    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
+    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);
     GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
 
+    if (b_type == GGML_TYPE_Q8_1) {
+        switch (a_type) {
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+                return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[a_type][num_cols-1];
+            default:
+                return nullptr;
+        }
+    }
+
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
@@ -4325,7 +4360,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
 }
 
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
-    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
+    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
     GGML_ASSERT(b_type == GGML_TYPE_F32);
 
     switch (a_type) {
@@ -5507,12 +5542,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
 
-    const bool qx_needs_dequant = x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
-    // Not implemented
-    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
     const uint64_t x_ne = ne01 * ne00;
     const uint64_t y_ne = ne11 * ne10;
     const uint64_t d_ne = ne11 * ne01;
@@ -5533,7 +5562,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
+
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+    // Check for mmq first
+    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11) : nullptr;
+    vk_pipeline to_q8_1 = nullptr;
+
+    if (dmmv == nullptr) {
+        // Fall back to f16 dequant mul mat
+        dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
+        quantize_y = false;
+    }
+
+    if (quantize_y) {
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    }
+
+    const bool qx_needs_dequant = x_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
+
+    // Not implemented
+    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
     GGML_ASSERT(dmmv != nullptr);
@@ -5549,7 +5600,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
             ctx->prealloc_size_x = x_sz_upd;
         }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
             ctx->prealloc_size_y = y_sz_upd;
         }
 
@@ -5560,6 +5611,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (qy_needs_dequant) {
            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
+        if (quantize_y) {
+            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+        }
         ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
         return;
     }
@@ -5590,6 +5644,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
+    } else if (quantize_y) {
+        d_Y = ctx->prealloc_y;
+        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -5604,6 +5661,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
+    if (quantize_y) {
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+    }
 
     // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
     uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
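ggml_vk_quantize_q8_1 dispatches a small compute shader that converts the f32 activations to Q8_1 before the integer-dot mul_mat_vec runs: each group of 32 floats becomes 32 signed bytes plus a scale d and the precomputed sum s. A CPU-side sketch of that per-block conversion (illustrative only, assuming the standard Q8_1 layout; scales shown as float where ggml stores fp16):

#include <cmath>
#include <cstdint>

struct block_q8_1 { float d, s; int8_t qs[32]; };

// Quantize one block of 32 f32 activations to Q8_1 (what the GPU shader does per block).
static block_q8_1 quantize_block_q8_1(const float * x) {
    float amax = 0.0f;
    for (int j = 0; j < 32; ++j) amax = std::fmax(amax, std::fabs(x[j]));

    const float d  = amax / 127.0f;                // largest magnitude maps to +/-127
    const float id = d != 0.0f ? 1.0f / d : 0.0f;

    block_q8_1 b{};
    b.d = d;
    int sum = 0;
    for (int j = 0; j < 32; ++j) {
        const int q = (int) std::lround(x[j] * id);
        b.qs[j] = (int8_t) q;
        sum += q;
    }
    b.s = d * (float) sum;                         // lets the kernel fold in the q4_0/q5_0 offset without a second pass
    return b;
}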
@@ -11285,7 +11345,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_CONCAT) {
         tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
+        tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
@@ -11400,7 +11460,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
     } else if (tensor->op == GGML_OP_SET_ROWS) {
-        tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
+        tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
     } else if (tensor->op == GGML_OP_CONT) {
         tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 11 additions & 0 deletions
@@ -8,10 +8,19 @@
 
 #include "types.comp"
 
+#ifndef MMQ
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#else
+layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+#endif
+
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+#ifdef B_TYPE_VEC2
 layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
+#endif
+#ifdef B_TYPE_VEC4
 layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+#endif
 
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #ifdef MUL_MAT_ID
@@ -88,6 +97,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 layout (constant_id = 1) const uint NUM_ROWS = 1;
 layout (constant_id = 2) const uint NUM_COLS = 1;
 
+#if !defined(MMQ) || !defined(USE_SUBGROUPS)
 shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
 
 void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
@@ -116,3 +126,4 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32
         }
     }
 }
+#endif
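On the MMQ path the weight buffer is rebound as A_TYPE_PACKED16, a block layout whose quantized payload is stored as 16-bit words, presumably so the shader can assemble packed 32-bit operands for the packed 4x8-bit integer dot instructions exposed by GL_EXT_integer_dot_product. The #if !defined(MMQ) || !defined(USE_SUBGROUPS) guard keeps the shared-memory tmpsh reduction only for variants that do not reduce their partial sums within a subgroup. A scalar C++ equivalent of one packed signed 4x8 dot-accumulate (an illustrative assumption, not code from this commit):

#include <cstdint>

// Scalar equivalent of a packed 4x8-bit signed dot product with accumulation:
// each 32-bit operand carries four int8 lanes.
static int32_t dot_packed_4x8_acc(uint32_t a, uint32_t b, int32_t acc) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t)(a >> (8 * i));
        const int8_t bi = (int8_t)(b >> (8 * i));
        acc += (int32_t) ai * (int32_t) bi;
    }
    return acc;
}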
