@@ -428,6 +428,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_matmul_split_k_reduce;
     vk_pipeline pipeline_quantize_q8_1;
+    vk_pipeline pipeline_quantize_q8_1_x4;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -2900,8 +2901,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
     }
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
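
For reference, a Q8_1 block stores 32 int8 quants plus two fp16 scalars (the scale `d` and the scaled sum `s`), 36 bytes in total. The new `_x4` pipeline targets a packed variant of this layout; the sketch below shows the plain block and one plausible x4 packing in which four blocks keep their scales in one contiguous region so a shader can fetch all of them in a single aligned access. The `block_q8_1_x4` struct here is illustrative only, not the shader's authoritative layout.

```cpp
#include <cstdint>

typedef uint16_t ggml_half;   // fp16 storage type
#define QK8_1 32              // values per Q8_1 block

struct block_q8_1 {           // 36 bytes: 2 + 2 + 32
    ggml_half d;              // scale
    ggml_half s;              // d * sum of the quants
    int8_t    qs[QK8_1];      // quantized values
};
static_assert(sizeof(block_q8_1) == 36, "Q8_1 block is 36 bytes");

struct block_q8_1_x4 {        // assumed x4 packing: the d,s pairs of four
    ggml_half ds[8];          // blocks come first (d0,s0 .. d3,s3), then
    int8_t    qs[4 * QK8_1];  // the four quant arrays back to back
};
```
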
@@ -5352,20 +5355,20 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
 }
 
-static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
     switch(type) {
         case GGML_TYPE_Q8_1:
-            return ctx->device->pipeline_quantize_q8_1;
+            return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
         default:
             std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
             GGML_ABORT("fatal error");
     }
 }
 
-static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
     VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
 
-    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, use_x4_blocks);
 
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
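
The dispatch above passes the raw element count as `elements`; `ggml_vk_dispatch_pipeline` divides each component by the pipeline's workgroup denominators (the `{32 * device->subgroup_size / 8, 1, 1}` supplied at pipeline creation), rounding up, to get the workgroup count. A worked example with an assumed subgroup size of 32:

```cpp
// Hypothetical values, for illustration only.
const uint32_t subgroup_size = 32;
const uint32_t wg_denom = 32 * subgroup_size / 8;          // 128 elements per workgroup
const uint32_t ne = 4096;                                  // elements to quantize
const uint32_t groups_x = (ne + wg_denom - 1) / wg_denom;  // 32 workgroups dispatched
```
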
@@ -5485,7 +5488,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
     }
 
     if (dryrun) {
@@ -5653,16 +5656,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne11 * ne01;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
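
The `quantize_y` gate that moved up here combines four requirements; a commented restatement with the assumed rationale for each clause:

```cpp
bool quantize_y =
    ctx->device->integer_dot_product  // device exposes integer dot product support
    && src1->type == GGML_TYPE_F32    // B must be f32 so it can be quantized on the fly
    && ggml_is_contiguous(src1)       // the quantize shader reads B linearly
    && (ne11 * ne10) % 4 == 0;        // int8 dot products consume values in packed groups of 4
```
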
@@ -5675,8 +5669,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
-
     // Check for mmq first
     vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11) : nullptr;
     vk_pipeline to_q8_1 = nullptr;
@@ -5688,7 +5680,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
     }
 
     const bool qx_needs_dequant = x_non_contig;
@@ -5701,6 +5693,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
     GGML_ASSERT(dmmv != nullptr);
 
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = ne11 * ne10;
+    const uint64_t d_ne = ne11 * ne01;
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
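
With `quantize_y` set, B is held as Q8_1 rather than f32/f16, which shrinks `y_sz`: `ggml_type_size(GGML_TYPE_Q8_1)` is 36 bytes per 32-element block. A worked example with assumed shapes:

```cpp
// Assumed example shapes, for illustration only.
const uint64_t ne10 = 4096, ne11 = 1;
const uint64_t y_ne = ne11 * ne10;                // 4096 elements
const uint64_t y_sz_q8_1 = y_ne * 36 / 32;        // 4608 bytes as Q8_1
const uint64_t y_sz_f32  = y_ne * sizeof(float);  // 16384 bytes as f32
```
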
@@ -5713,7 +5715,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
+            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
         }
 
         // Request descriptor sets
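
`CEIL_DIV` is ggml-vulkan's round-up integer division, so the preallocation is padded to the next 128-byte multiple; presumably this gives the x4 quantize shader headroom to write whole groups of four blocks without overrunning the buffer (the constant 128 comes from the commit itself, the rationale is an assumption):

```cpp
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))  // as defined in ggml-vulkan.cpp

// e.g. y_sz_upd = 4600 bytes:
//   CEIL_DIV(4600, 128) * 128 == 36 * 128 == 4608
```
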
@@ -5758,7 +5760,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -5774,7 +5776,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
     if (quantize_y) {
-        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
     }
 
     // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride