@@ -452,6 +452,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_matmul_split_k_reduce;
     vk_pipeline pipeline_quantize_q8_1;
+    vk_pipeline pipeline_quantize_q8_1_x4;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -3005,8 +3006,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
     }
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
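
Both branches above register the new `quantize_q8_1_x4` shader alongside the existing `quantize_q8_1` one: a subgroup-clustered variant when the device supports clustered subgroup operations and can require full subgroups, and a portable fallback otherwise. A minimal standalone sketch of that selection pattern, with placeholder capability fields rather than the real `vk_device` members:

```cpp
// Standalone sketch of the pipeline-variant selection above: the subgroup-
// clustered shader is chosen only when clustered subgroup ops are available
// and full subgroups can be required; otherwise the fallback SPIR-V is built.
// The struct and its fields are placeholders, not the real vk_device layout.
#include <cstdint>
#include <cstdio>

struct device_caps {
    bool     subgroup_clustered;             // clustered subgroup arithmetic available
    bool     subgroup_require_full_support;  // full subgroups can be required
    uint32_t subgroup_size;
};

static const char * pick_quantize_variant(const device_caps & dev, uint32_t & wg_x) {
    wg_x = 32 * dev.subgroup_size / 8;  // same workgroup-width formula as above
    return (dev.subgroup_clustered && dev.subgroup_require_full_support)
        ? "quantize_q8_1_x4 (subgroup variant)"
        : "quantize_q8_1_x4 (fallback)";
}

int main() {
    const device_caps dev{true, true, 32};
    uint32_t wg_x = 0;
    std::printf("%s, workgroup x = %u\n", pick_quantize_variant(dev, wg_x), (unsigned) wg_x);
    return 0;
}
```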
@@ -5548,20 +5551,20 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
-static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
     switch(type) {
         case GGML_TYPE_Q8_1:
-            return ctx->device->pipeline_quantize_q8_1;
+            return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
         default:
             std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
             GGML_ABORT("fatal error");
     }
 }
 
-static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
     VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
 
-    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, use_x4_blocks);
 
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
     ggml_vk_sync_buffers(ctx, subctx);
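
`ggml_vk_quantize_q8_1` gains a trailing `use_x4_blocks` flag that defaults to `false`, so callers that do not pass it keep producing the plain Q8_1 layout while updated call sites can opt into the x4-packed variant. A small standalone sketch of that compatibility pattern, using a stand-in function rather than the real ggml entry point:

```cpp
// Standalone sketch: the new trailing flag defaults to false, so existing
// callers are unchanged and only updated call sites request the x4-packed
// layout. quantize_q8_1_sketch stands in for ggml_vk_quantize_q8_1 and only
// prints what would be dispatched.
#include <cstdio>

static void quantize_q8_1_sketch(unsigned ne, bool use_x4_blocks = false) {
    std::printf("quantize %u values -> %s layout\n",
                ne, use_x4_blocks ? "q8_1_x4 (packed)" : "q8_1 (plain)");
}

int main() {
    quantize_q8_1_sketch(4096);        // existing callers: original behaviour
    quantize_q8_1_sketch(4096, true);  // mul_mat_vec path: opt into x4 blocks
    return 0;
}
```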
@@ -5681,7 +5684,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
     }
 
     if (dryrun) {
@@ -5877,16 +5880,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne11 * ne01;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
@@ -5899,8 +5893,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
-
     // Check for mmq first
     vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
     vk_pipeline to_q8_1 = nullptr;
@@ -5912,7 +5904,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
     }
 
     const bool qx_needs_dequant = x_non_contig;
@@ -5925,6 +5917,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
     GGML_ASSERT(dmmv != nullptr);
 
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = ne11 * ne10;
+    const uint64_t d_ne = ne11 * ne01;
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -5937,7 +5939,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
+            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
         }
 
         // Request descriptor sets
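
With `quantize_y` set, `y_sz` now counts Q8_1 blocks instead of float/fp16 elements, and the preallocation for the quantized copy of src1 is rounded up to a multiple of 128 bytes. A minimal sketch of that arithmetic, assuming ggml's Q8_1 blocks of 32 values occupying 36 bytes each (2-byte scale, 2-byte sum, 32 int8 quants); the tensor shape is a made-up example:

```cpp
// Minimal sketch (no ggml headers) of the y_sz / prealloc_size_y bookkeeping
// above. The block constants mirror ggml_blck_size(GGML_TYPE_Q8_1) and
// ggml_type_size(GGML_TYPE_Q8_1); the shape values are arbitrary examples.
#include <cstdint>
#include <cstdio>

static constexpr uint64_t QK8_1         = 32;  // values per Q8_1 block
static constexpr uint64_t Q8_1_BLOCK_SZ = 36;  // bytes per Q8_1 block

static uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

int main() {
    const uint64_t ne10 = 2880, ne11 = 1;                  // example src1 row length / row count
    const uint64_t y_ne  = ne11 * ne10;                    // elements to quantize
    const uint64_t y_sz  = y_ne / QK8_1 * Q8_1_BLOCK_SZ;   // quantized size in bytes
    const uint64_t alloc = ceil_div(y_sz, 128) * 128;      // padded preallocation size
    std::printf("y_ne=%llu y_sz=%llu prealloc=%llu\n",
                (unsigned long long) y_ne,
                (unsigned long long) y_sz,
                (unsigned long long) alloc);
    return 0;
}
```

For ne10 = 2880 this gives 90 blocks, 3240 bytes of quantized data, and 3328 bytes of preallocation after padding.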
@@ -5982,7 +5984,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -6014,7 +6016,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         }
     }
     if (quantize_y) {
-        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
     }
 
     // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride