@@ -2919,7 +2919,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
         uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);
 
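+        // use the subgroup-optimized dmmv shader variants only when full subgroup arithmetic is supported; AMD GCN stays on the non-subgroup path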
-        const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
+        const bool s = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
 
         for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
@@ -2970,8 +2970,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
         }
+    }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
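+    // the Q8_1 integer-dot pipelines are indexed by type and column count only, not by the workgroup-size variant w, so they get their own loop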
+    for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         if (device->integer_dot_product) {
             const uint32_t subgroup_size = (device->subgroup_size_control && device->vendor_id == VK_VENDOR_ID_INTEL) ? device->subgroup_min_size : device->subgroup_size;
             if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
@@ -2988,8 +2990,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
                 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_q8_1_f32_len, mul_mat_vec_q8_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {subgroup_size, 1*rm_stdq, i+1}, 1, true);
             }
         }
-#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -5769,7 +5771,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        uint64_t y_sz_upd = y_sz * ne12 * ne13;
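+        // assumption: quantize_q8_1 writes whole groups of 4 Q8_1 blocks (4 * 36 bytes = 144), so the quantized copy of src1 is padded to a full group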
+        if (quantize_y) {
+            y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+        }
         const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -5781,7 +5786,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
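+            // y_sz_upd already includes the 144-byte Q8_1 padding when quantize_y is set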
-            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
+            ctx->prealloc_size_y = y_sz_upd;
         }
         if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
             ctx->prealloc_size_split_k = split_k_size;
@@ -5835,7 +5840,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
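+        // prealloc_y was sized with the same 144-byte (4-block) rounding during the dryrun pass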
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -5889,10 +5894,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
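+    // total size of the src1 view passed to the matmul; rounded up to whole Q8_1 block groups when src1 was quantized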
+    uint32_t y_sz_total = y_sz * ne12 * ne13;
+    if (quantize_y) {
+        y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+    }
+
     // compute
     ggml_vk_matmul(
         ctx, subctx, pipeline,
-        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
         { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
         ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
@@ -6010,7 +6020,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        uint64_t y_sz_upd = y_sz * ne12 * ne13;
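+        // same 4-block (144-byte) Q8_1 padding as in ggml_vk_mul_mat_q_f16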
+        if (quantize_y) {
+            y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+        }
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
             (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
@@ -6020,7 +6033,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
+            ctx->prealloc_size_y = y_sz_upd;
         }
 
         // Request descriptor sets
@@ -6065,7 +6078,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -6121,14 +6134,20 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
+    // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time
+    uint32_t y_sz_total = y_sz * ne12 * ne13;
+    if (quantize_y) {
+        y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+    }
+
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         stride_batch_x, stride_batch_y, stride_batch_d,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
+        { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
         pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 
     if (x_non_contig) {