Commit d5192bf

enable MUL_MAT_ID mmvq support
1 parent 0297ff2

File tree: 2 files changed, +112 -23 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp (110 additions, 23 deletions)
@@ -572,6 +572,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@@ -3447,6 +3448,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    if (device->integer_dot_product) {
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", mul_mat_vec_id_q4_0_q8_1_f32_len, mul_mat_vec_id_q4_0_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", mul_mat_vec_id_q4_1_q8_1_f32_len, mul_mat_vec_id_q4_1_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", mul_mat_vec_id_q5_0_q8_1_f32_len, mul_mat_vec_id_q5_0_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", mul_mat_vec_id_q5_1_q8_1_f32_len, mul_mat_vec_id_q5_1_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", mul_mat_vec_id_q8_0_q8_1_f32_len, mul_mat_vec_id_q8_0_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", mul_mat_vec_id_mxfp4_q8_1_f32_len, mul_mat_vec_id_mxfp4_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {device->subgroup_size, 2*rm_stdq_int}, 1, true);
+
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", mul_mat_vec_id_q2_k_q8_1_f32_len, mul_mat_vec_id_q2_k_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {device->subgroup_size, 2*rm_kq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", mul_mat_vec_id_q3_k_q8_1_f32_len, mul_mat_vec_id_q3_k_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", mul_mat_vec_id_q4_k_q8_1_f32_len, mul_mat_vec_id_q4_k_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", mul_mat_vec_id_q5_k_q8_1_f32_len, mul_mat_vec_id_q5_k_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", mul_mat_vec_id_q6_k_q8_1_f32_len, mul_mat_vec_id_q6_k_q8_1_f32_data, "main", 4, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
+    }
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
+
     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
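
All eleven registrations above follow one pattern: the pipeline name is mul_mat_vec_id_<tname>_q8_1_f32, and the workgroup shape {rm, 1, 1} scales by a per-type row multiplier (1*rm_stdq_int for the 32-wide standard quants, 2*rm_stdq_int for MXFP4, 2*rm_kq_int for Q2_K, 1*rm_kq_int for the remaining k-quants; the multipliers themselves are tuned elsewhere in this file). A standalone sketch of the naming side, using the type-name strings that vulkan-shaders-gen.cpp iterates over:

    #include <cstdio>
    #include <string>

    // Sketch only: reproduces the shader-name convention used by the
    // registrations above and by the string_to_spv call in the second file.
    static std::string mmvq_id_shader_name(const std::string & tname) {
        return "mul_mat_vec_id_" + tname + "_q8_1_f32";
    }

    int main() {
        const char * tnames[] = {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
                                 "mxfp4", "q2_k", "q3_k", "q4_k", "q5_k", "q6_k"};
        for (const char * t : tnames) {
            std::printf("%s\n", mmvq_id_shader_name(t).c_str());
        }
        return 0;
    }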
@@ -5303,6 +5322,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
 
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
     VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
+
+    if (b_type == GGML_TYPE_Q8_1) {
+        switch (a_type) {
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_MXFP4:
+            case GGML_TYPE_Q2_K:
+            case GGML_TYPE_Q3_K:
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_Q6_K:
+                break;
+            default:
+                return nullptr;
+        }
+
+        return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[a_type];
+    }
+
     GGML_ASSERT(b_type == GGML_TYPE_F32);
 
     switch (a_type) {
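
The Q8_1 branch doubles as a capability probe: any quant format without a MUL_MAT_ID mmvq pipeline yields nullptr, and the caller is expected to fall back to the f32 path (the fallback is visible in ggml_vk_mul_mat_vec_id_q_f16 further down). A self-contained restatement of the supported set, with a stand-in enum since the real GGML_TYPE_* values live in ggml.h:

    #include <cstdio>

    // Illustrative stand-ins for the ggml types accepted by the Q8_1 path above.
    enum class QuantType { Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, MXFP4,
                           Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, IQ4_NL };

    // Mirrors the switch in ggml_vk_get_dequantize_mul_mat_vec_id: only these
    // eleven formats have a MUL_MAT_ID mmvq pipeline; anything else (e.g.
    // IQ4_NL) must take the f32 fallback.
    static bool has_mmvq_id_pipeline(QuantType t) {
        switch (t) {
            case QuantType::Q4_0: case QuantType::Q4_1:
            case QuantType::Q5_0: case QuantType::Q5_1:
            case QuantType::Q8_0: case QuantType::MXFP4:
            case QuantType::Q2_K: case QuantType::Q3_K:
            case QuantType::Q4_K: case QuantType::Q5_K:
            case QuantType::Q6_K:
                return true;
            default:
                return false;
        }
    }

    int main() {
        std::printf("IQ4_NL supported: %d\n", has_mmvq_id_pipeline(QuantType::IQ4_NL)); // 0
        return 0;
    }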
@@ -6483,6 +6524,11 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
         return false;
     }
 
+    // General issue with q3_k and q6_k
+    if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
+        return false;
+    }
+
     // MMVQ is generally good for batches
     if (n > 1) {
         return true;
@@ -6492,6 +6538,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
     case VK_VENDOR_ID_NVIDIA:
         switch (src0_type) {
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
            return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
        default:
            return true;
@@ -7329,12 +7376,41 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
+
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
+    if (x_non_contig) {
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
+    }
+    if (y_non_contig) {
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
+    } else {
+        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+    }
+
+    // Check for mmq first
+    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1) : nullptr;
+    vk_pipeline to_q8_1 = nullptr;
+
+    if (dmmv == nullptr) {
+        // Fall back to f16 dequant mul mat
+        dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
+        quantize_y = false;
+    }
+
+    if (quantize_y) {
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
+    }
 
     const bool qx_needs_dequant = x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+    GGML_ASSERT(dmmv != nullptr);
 
     const uint64_t x_ne = ne01 * ne00;
     const uint64_t y_ne = ne11 * ne10;
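
The quantize_y gate only turns on when src1 is contiguous F32, its element count is divisible by 4 (presumably matching the vectorized loads of the Q8_1 quantize shader, which this diff does not show), the device reports integer-dot support, and the mmvq heuristic approves. A compilable restatement, with the device capability and ggml_vk_should_use_mmvq() heuristic reduced to plain booleans:

    #include <cstdint>
    #include <cstdio>

    // Hedged restatement of the quantize_y condition above; parameter names
    // mirror the ggml ne10/ne11 extents but all checks are simplified.
    static bool can_quantize_y(bool integer_dot_product, bool src1_is_f32,
                               bool src1_contiguous, uint64_t ne10, uint64_t ne11,
                               bool mmvq_worthwhile) {
        return integer_dot_product && src1_is_f32 && src1_contiguous &&
               (ne11 * ne10) % 4 == 0 && mmvq_worthwhile;
    }

    int main() {
        // ne10 = 4096 columns, ne11 = 1 row: 4096 % 4 == 0, so the gate passes.
        std::printf("%d\n", can_quantize_y(true, true, true, 4096, 1, true)); // 1
        return 0;
    }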
@@ -7343,28 +7419,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
     const uint64_t ids_sz = nbi2;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
-    if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
-    }
-    if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
-    } else {
-        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
-    }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
-    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
-    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
-    GGML_ASSERT(dmmv != nullptr);
-
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (quantize_y) {
+            y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+        }
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
             (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
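
The constant 144 matches ggml's Q8_1 block layout, assuming the usual 36-byte block (32 int8 quants plus two fp16 scale fields per 32 values): 144 = 4 * 36, so the preallocation is rounded up to whole four-block groups. A quick check of the arithmetic:

    #include <cassert>
    #include <cstdint>

    // CEIL_DIV as used in ggml-vulkan.cpp.
    static uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

    int main() {
        // Assumed Q8_1 layout (not shown in this diff): 32 int8 quants plus
        // two fp16 fields per block -> 32 + 2*2 = 36 bytes per 32 values.
        const uint64_t q8_1_block_bytes = 36;
        assert(144 == 4 * q8_1_block_bytes); // 144 pads to whole 4-block groups

        uint64_t y_sz_upd = 1000;                  // example unpadded size
        y_sz_upd = ceil_div(y_sz_upd, 144) * 144;  // rounds up to 1008
        assert(y_sz_upd == 1008);
        return 0;
    }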
@@ -7373,7 +7437,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
             ctx->prealloc_size_x = x_sz_upd;
         }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
             ctx->prealloc_size_y = y_sz_upd;
         }
 
@@ -7384,6 +7448,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (qy_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
+        if (quantize_y) {
+            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+        }
         ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
         return;
     }
@@ -7419,6 +7486,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
+    } else if (quantize_y) {
+        d_Y = ctx->prealloc_y;
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -7447,6 +7517,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
             ctx->prealloc_y_last_tensor_used = src1;
         }
     }
+    if (quantize_y) {
+        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            if (ctx->prealloc_y_need_sync) {
+                ggml_vk_sync_buffers(ctx, subctx);
+            }
+            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
+            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
+    }
 
     uint32_t stride_batch_y = ne10*ne11;
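
The (pipeline, tensor) pair acts as a cache key here: consecutive expert mat-vecs over the same src1 reuse the already-quantized prealloc_y buffer instead of re-running the quantize shader. A generic sketch of that pattern, with pointer types standing in for vk_pipeline and ggml_tensor:

    #include <cstdio>

    // Sketch of the prealloc_y reuse logic above: the conversion runs only
    // when the (pipeline, tensor) pair differs from the previous dispatch.
    struct ConversionCache {
        const void * last_pipeline = nullptr;
        const void * last_tensor   = nullptr;
    };

    template <typename F>
    static void convert_if_needed(ConversionCache & c, const void * pipeline,
                                  const void * tensor, F && run) {
        if (c.last_pipeline != pipeline || c.last_tensor != tensor) {
            run(); // e.g. quantize src1 into the preallocated Y buffer
            c.last_pipeline = pipeline;
            c.last_tensor   = tensor;
        }
    }

    int main() {
        ConversionCache c;
        int pipeline = 0, tensor = 0, runs = 0;
        convert_if_needed(c, &pipeline, &tensor, [&] { ++runs; });
        convert_if_needed(c, &pipeline, &tensor, [&] { ++runs; }); // cache hit
        std::printf("runs = %d\n", runs); // prints: runs = 1
        return 0;
    }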

@@ -7464,16 +7545,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
+    uint32_t y_sz_total = y_sz * ne12 * ne13;
+    if (quantize_y) {
+        y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+    }
+
     // compute
     const vk_mat_vec_id_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
         (uint32_t)nei0, (uint32_t)ne11,
     };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
-          vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
-        pc, { groups_x, (uint32_t)nei0, groups_z });
+    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
+                                                   vk_subbuffer{ d_Y, y_buf_offset, y_sz_total },
+                                                   vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23},
+                                                   vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
+                                                   pc, { groups_x, (uint32_t)nei0, groups_z });
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp (2 additions, 0 deletions)
@@ -668,6 +668,8 @@ void process_shaders() {
         string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
         string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
         string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
+
+        string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
     }
 #endif
 
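
The id variant is compiled from the same mul_mat_vecq.comp source; the only new define is MUL_MAT_ID=1, and unlike the non-id entries no subgroup-add variants are emitted for it. A minimal driver in the spirit of process_shaders(), where merge_maps and the define keys are illustrative stand-ins rather than the real helpers:

    #include <cstdio>
    #include <map>
    #include <string>

    using Defines = std::map<std::string, std::string>;

    // Stand-in for the merge_maps() helper in vulkan-shaders-gen.cpp; the
    // collision behavior here is illustrative only.
    static Defines merge_maps(const Defines & a, const Defines & b) {
        Defines out = a;
        for (const auto & kv : b) out[kv.first] = kv.second;
        return out;
    }

    int main() {
        const Defines base = {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}};
        // Hypothetical data_a_key for tname "q4_0"; the real key string is
        // built inside vulkan-shaders-gen.cpp.
        const Defines id_variant = merge_maps(base, {{"MUL_MAT_ID", "1"}, {"DATA_A_Q4_0", "1"}});
        for (const auto & kv : id_variant) {
            std::printf("-D%s=%s\n", kv.first.c_str(), kv.second.c_str());
        }
        return 0;
    }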
