@@ -2847,7 +2847,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
28472847 ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
28482848 }
28492849 }
2850- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
2850+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
28512851
28522852 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
28532853 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -5742,7 +5742,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
57425742 const uint64_t ne00 = src0->ne[0];
57435743 const uint64_t ne01 = src0->ne[1];
57445744 const uint64_t ne02 = src0->ne[2];
5745- // const uint64_t ne03 = src0->ne[3];
5745+ const uint64_t ne03 = src0->ne[3];
57465746
57475747 const uint64_t nb01 = src0->nb[1];
57485748 const uint64_t nb02 = src0->nb[2];
@@ -5754,7 +5754,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
57545754 const uint64_t ne12 = src1->ne[2];
57555755 // const uint64_t ne13 = src1->ne[3];
57565756
5757+ const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
5758+ const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
5759+ const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
5760+
57575761 GGML_ASSERT(ne11 == 1);
5762+ GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
57585763
57595764 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
57605765 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
@@ -5770,7 +5775,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
57705775 src1_uma = d_Qy != nullptr;
57715776 }
57725777
5773- const uint64_t d_ne = ne01 * ne11 * ne12;
5778+ const uint64_t d_ne = ne01 * ne11 * ne12 * ne03 ;
57745779
57755780 const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
57765781 const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
@@ -5805,10 +5810,10 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
58055810 const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
58065811
58075812 // compute
5808- const std::array<uint32_t, 9 > pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
5813+ const std::array<uint32_t, 12 > pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
58095814 ggml_vk_sync_buffers(subctx);
58105815 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
5811- { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1 , (uint32_t)ne01, (uint32_t)ne12 });
5816+ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03 , (uint32_t)ne01, (uint32_t)ne12 });
58125817}
58135818
58145819static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
0 commit comments