@@ -2787,7 +2787,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
27872787 uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
27882788 uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);
27892789
2790- const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
2790+ const bool s = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
27912791
27922792 for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
27932793 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
@@ -2838,8 +2838,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
28382838 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
28392839 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
28402840 }
2841+ }
28412842
28422843#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2844+ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
28432845 if (device->integer_dot_product) {
28442846 const uint32_t subgroup_size = (device->subgroup_size_control && device->vendor_id == VK_VENDOR_ID_INTEL) ? device->subgroup_min_size : device->subgroup_size;
28452847 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
@@ -2856,8 +2858,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
28562858 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_q8_1_f32_len, mul_mat_vec_q8_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {subgroup_size, 1*rm_stdq, i+1}, 1, true);
28572859 }
28582860 }
2859- #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
28602861 }
2862+ #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
28612863
28622864 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
28632865 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -5614,7 +5616,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
56145616
56155617 if (dryrun) {
56165618 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
5617- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
5619+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
5620+ if (quantize_y) {
5621+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
5622+ }
56185623 const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
56195624 if (
56205625 (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -5626,7 +5631,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
56265631 ctx->prealloc_size_x = x_sz_upd;
56275632 }
56285633 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
5629- ctx->prealloc_size_y = CEIL_DIV( y_sz_upd, 128) * 128 ;
5634+ ctx->prealloc_size_y = y_sz_upd;
56305635 }
56315636 if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
56325637 ctx->prealloc_size_split_k = split_k_size;
@@ -5680,7 +5685,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
56805685 GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
56815686 } else if (quantize_y) {
56825687 d_Y = ctx->prealloc_y;
5683- GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1) );
5688+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144 );
56845689 } else {
56855690 d_Y = d_Qy;
56865691 y_buf_offset = qy_buf_offset;
@@ -5712,10 +5717,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
57125717 stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
57135718 }
57145719
5720+ uint32_t y_sz_total = y_sz * ne12 * ne13;
5721+ if (quantize_y) {
5722+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
5723+ }
5724+
57155725 // compute
57165726 ggml_vk_matmul(
57175727 ctx, subctx, pipeline,
5718- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
5728+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
57195729 { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
57205730 ne01, ne11, ne10,
57215731 ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
@@ -5826,7 +5836,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
58265836
58275837 if (dryrun) {
58285838 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
5829- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
5839+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
5840+ if (quantize_y) {
5841+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
5842+ }
58305843 if (
58315844 (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
58325845 (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
@@ -5836,7 +5849,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
58365849 ctx->prealloc_size_x = x_sz_upd;
58375850 }
58385851 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
5839- ctx->prealloc_size_y = CEIL_DIV( y_sz_upd, 128) * 128 ;
5852+ ctx->prealloc_size_y = y_sz_upd;
58405853 }
58415854
58425855 // Request descriptor sets
@@ -5881,7 +5894,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
58815894 d_Y = ctx->prealloc_y;
58825895 } else if (quantize_y) {
58835896 d_Y = ctx->prealloc_y;
5884- GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128 ) * 128 );
5897+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144 ) * 144 );
58855898 } else {
58865899 d_Y = d_Qy;
58875900 y_buf_offset = qy_buf_offset;
@@ -5923,6 +5936,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
59235936 groups_x = CEIL_DIV(groups_x, groups_z);
59245937 }
59255938
5939+ // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time
5940+ uint32_t y_sz_total = y_sz * ne12 * ne13;
5941+ if (quantize_y) {
5942+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
5943+ }
5944+
59265945 // compute
59275946 const vk_mat_vec_push_constants pc = {
59285947 (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
@@ -5931,7 +5950,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
59315950 };
59325951 ggml_vk_sync_buffers(subctx);
59335952 ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
5934- { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
5953+ { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
59355954 pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
59365955}
59375956