2929
3030#include " ggml-vulkan-shaders.hpp"
3131
// Round M up to the next multiple of N. N must be a power of two,
// otherwise the bit-mask trick below is invalid.
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
// Ceiling division: the number of N-sized chunks needed to cover M.
// Works for any positive N (no power-of-two requirement).
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
3334
3435#define VK_VENDOR_ID_AMD 0x1002
@@ -368,6 +369,7 @@ struct vk_mat_mat_push_constants {
368369 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
369370 uint32_t k_split;
370371 uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
372+ uint32_t padded_N;
371373};
372374struct vk_mat_vec_push_constants {
373375 uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -380,6 +382,7 @@ struct vk_mat_mat_id_push_constants {
380382 uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
381383 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
382384 uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
385+ uint32_t padded_N;
383386};
384387struct vk_mat_vec_id_push_constants {
385388 uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -3882,18 +3885,19 @@ static void ggml_vk_matmul(
38823885 vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
38833886 uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
38843887 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
3885- uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
3888+ uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
3889+ uint32_t padded_n) {
38863890 VK_LOG_DEBUG (" ggml_vk_matmul(a: (" << a.buffer ->buffer << " , " << a.offset << " , " << a.size << " ), b: (" << b.buffer ->buffer << " , " << b.offset << " , " << b.size << " ), d: (" << d.buffer ->buffer << " , " << d.offset << " , " << d.size << " ), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer ->buffer : VK_NULL_HANDLE) << " , " << split_k_buffer.offset << " , " << split_k_buffer.size << " ), m: " << m << " , n: " << n << " , k: " << k << " , stride_a: " << stride_a << " , stride_b: " << stride_b << " , stride_d: " << stride_d << " , batch_stride_a: " << batch_stride_a << " , batch_stride_b: " << batch_stride_b << " , batch_stride_d: " << batch_stride_d << " , split_k: " << split_k << " , batch: " << batch << " , ne02: " << ne02 << " , ne12: " << ne12 << " , broadcast2: " << broadcast2 << " , broadcast3: " << broadcast3 << " )" );
38873891 ggml_vk_sync_buffers (subctx);
38883892 if (split_k == 1 ) {
3889- const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
3893+ const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
38903894 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { a, b, d }, sizeof (vk_mat_mat_push_constants), &pc, { m, n, batch });
38913895 return ;
38923896 }
38933897
38943898 GGML_ASSERT (batch_stride_d == m * n);
38953899
3896- const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV (k, split_k), ne02, ne12, broadcast2, broadcast3 };
3900+ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV (k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
38973901 // Make sure enough workgroups get assigned for split k to work
38983902 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof (vk_mat_mat_push_constants), &pc1, { (CEIL_DIV (m, pipeline->wg_denoms [0 ]) * pipeline->wg_denoms [0 ]) * split_k, n, batch });
38993903 ggml_vk_sync_buffers (subctx);
@@ -3937,14 +3941,15 @@ static void ggml_vk_matmul_id(
39373941 vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
39383942 uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
39393943 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
3940- uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
3944+ uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
3945+ uint32_t padded_n) {
39413946 VK_LOG_DEBUG (" ggml_vk_matmul_id(a: (" << a.buffer ->buffer << " , " << a.offset << " , " << a.size << " ), b: (" << b.buffer ->buffer << " , " << b.offset << " , " << b.size << " ), d: (" << d.buffer ->buffer << " , " << d.offset << " , " << d.size << " ), ids: (" << ids.buffer ->buffer << " , " << ids.offset << " , " << ids.size << " ), " <<
39423947 " m: " << m << " , n: " << n << " , k: " << k << " , stride_a: " << stride_a << " , stride_b: " << stride_b << " , stride_d: " << stride_d << " , " <<
39433948 " batch_stride_a: " << batch_stride_a << " , batch_stride_b: " << batch_stride_b << " , batch_stride_d: " << batch_stride_d << " , " <<
39443949 " n_as: " << n_as << " , nei0: " << nei0 << " , nei1: " << nei1 << " , nbi1: " << nbi1 << " , ne11: " << ne11 << " )" );
39453950 ggml_vk_sync_buffers (subctx);
39463951 const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
3947- nei0, nei1, nbi1, ne11 };
3952+ nei0, nei1, nbi1, ne11, padded_n };
39483953 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { a, b, d, ids }, sizeof (vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
39493954}
39503955
@@ -4106,15 +4111,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
41064111 // Not implemented
41074112 GGML_ASSERT (y_non_contig || !qy_needs_dequant); // NOLINT
41084113
4109- const int x_ne = ne01 * ne00;
4110- const int y_ne = ne11 * ne10;
4111- const int d_ne = ne11 * ne01;
4112-
41134114 const uint32_t kpad = ggml_vk_align_size (ne10, ggml_vk_guess_matmul_pipeline_align (ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type ));
41144115 const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8 ;
41154116
41164117 vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline (ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type );
41174118
4119+ // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4120+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2 (ne11, pipeline->wg_denoms [1 ]) :ne11;
4121+ const int x_ne = ne01 * ne00;
4122+ const int y_ne = padded_n * ne10;
4123+ const int d_ne = ne11 * ne01;
4124+
41184125 const uint32_t split_k = ggml_vk_guess_split_k (ctx, ne01, ne11, ne10, pipeline);
41194126
41204127 const uint64_t qx_sz = ggml_type_size (src0->type ) * x_ne / ggml_blck_size (src0->type );
@@ -4237,7 +4244,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
42374244 { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k , 0 , d_sz * ne12 * ne13 * split_k },
42384245 ne01, ne11, ne10,
42394246 ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
4240- split_k, ne12*ne13, ne02, ne12, r2, r3
4247+ split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
42414248 ); // NOLINT
42424249}
42434250
@@ -4688,15 +4695,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
46884695 // Not implemented
46894696 GGML_ASSERT (y_non_contig || !qy_needs_dequant); // NOLINT
46904697
4691- const uint64_t x_ne = ne01 * ne00;
4692- const uint64_t y_ne = ne11 * ne10;
4693- const uint64_t d_ne = ne21 * ne20;
4694-
46954698 const uint32_t kpad = ggml_vk_align_size (ne10, ggml_vk_guess_matmul_id_pipeline_align (ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type ));
46964699 const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8 ;
46974700
46984701 vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline (ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type );
46994702
4703+ // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4704+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2 (ne11, pipeline->wg_denoms [1 ]) :ne11;
4705+ const uint64_t x_ne = ne01 * ne00;
4706+ const uint64_t y_ne = padded_n * ne10;
4707+ const uint64_t d_ne = ne21 * ne20;
4708+
47004709 const uint64_t qx_sz = ggml_type_size (src0->type ) * x_ne / ggml_blck_size (src0->type );
47014710 const uint64_t qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
47024711 const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof (ggml_fp16_t ) * x_ne;
@@ -4815,7 +4824,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
48154824 { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
48164825 ne01, ne21, ne10, ne10, ne10, ne01,
48174826 stride_batch_x, stride_batch_y, ne20*ne21,
4818- n_as, nei0, nei1, nbi1 / ggml_type_size (ids->type ), ne11
4827+ n_as, nei0, nei1, nbi1 / ggml_type_size (ids->type ), ne11, padded_n
48194828 ); // NOLINT
48204829}
48214830
@@ -6775,7 +6784,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
67756784 ctx, subctx, p, ggml_vk_subbuffer (d_X), ggml_vk_subbuffer (d_Y), ggml_vk_subbuffer (d_D), ggml_vk_subbuffer (ctx->prealloc_split_k ),
67766785 m, n, k,
67776786 k, k, m, k*m, k*n, m*n,
6778- split_k, batch, batch, batch, 1 , 1
6787+ split_k, batch, batch, batch, 1 , 1 , n
67796788 );
67806789 }
67816790 ggml_vk_ctx_end (subctx);
@@ -7120,7 +7129,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
71207129 ctx, subctx, p, ggml_vk_subbuffer (qx_buf), ggml_vk_subbuffer (y_buf), ggml_vk_subbuffer (d_buf), ggml_vk_subbuffer (ctx->prealloc_split_k ),
71217130 m, n, k,
71227131 k, k, m, k*m, k*n, m*n,
7123- split_k, batch, batch, batch, 1 , 1
7132+ split_k, batch, batch, batch, 1 , 1 , n
71247133 );
71257134 }
71267135 ggml_vk_ctx_end (subctx);
0 commit comments