2929
3030#include  " ggml-vulkan-shaders.hpp" 
3131
32+ #define  ROUNDUP_POW2 (M, N ) (((M) + (N) - 1 ) & ~((N) - 1 ))
3233#define  CEIL_DIV (M, N ) (((M) + (N)-1 ) / (N))
3334
3435#define  VK_VENDOR_ID_AMD  0x1002 
@@ -368,6 +369,7 @@ struct vk_mat_mat_push_constants {
368369    uint32_t  batch_stride_a; uint32_t  batch_stride_b; uint32_t  batch_stride_d;
369370    uint32_t  k_split;
370371    uint32_t  ne02; uint32_t  ne12; uint32_t  broadcast2; uint32_t  broadcast3;
372+     uint32_t  padded_N;
371373};
372374struct  vk_mat_vec_push_constants  {
373375    uint32_t  ncols; uint32_t  stride_a; uint32_t  stride_b; uint32_t  stride_d;
@@ -380,6 +382,7 @@ struct vk_mat_mat_id_push_constants {
380382    uint32_t  stride_a; uint32_t  stride_b; uint32_t  stride_d;
381383    uint32_t  batch_stride_a; uint32_t  batch_stride_b; uint32_t  batch_stride_d;
382384    uint32_t  nei0; uint32_t  nei1; uint32_t  nbi1; uint32_t  ne11;
385+     uint32_t  padded_N;
383386};
384387struct  vk_mat_vec_id_push_constants  {
385388    uint32_t  ncols; uint32_t  stride_a; uint32_t  stride_b; uint32_t  stride_d;
@@ -3882,18 +3885,19 @@ static void ggml_vk_matmul(
38823885        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
38833886        uint32_t  m, uint32_t  n, uint32_t  k, uint32_t  stride_a, uint32_t  stride_b, uint32_t  stride_d,
38843887        uint32_t  batch_stride_a, uint32_t  batch_stride_b, uint32_t  batch_stride_d,
3885-         uint32_t  split_k, uint32_t  batch, uint32_t  ne02, uint32_t  ne12, uint32_t  broadcast2, uint32_t  broadcast3) {
3888+         uint32_t  split_k, uint32_t  batch, uint32_t  ne02, uint32_t  ne12, uint32_t  broadcast2, uint32_t  broadcast3,
3889+         uint32_t  padded_n) {
38863890        VK_LOG_DEBUG (" ggml_vk_matmul(a: ("   << a.buffer ->buffer  << " , "   << a.offset  << " , "   << a.size  << " ), b: ("   << b.buffer ->buffer  << " , "   << b.offset  << " , "   << b.size  << " ), d: ("   << d.buffer ->buffer  << " , "   << d.offset  << " , "   << d.size  << " ), split_k: ("   << (split_k_buffer.buffer  != nullptr  ? split_k_buffer.buffer ->buffer  : VK_NULL_HANDLE) << " , "   << split_k_buffer.offset  << " , "   << split_k_buffer.size  << " ), m: "   << m << " , n: "   << n << " , k: "   << k << " , stride_a: "   << stride_a << " , stride_b: "   << stride_b << " , stride_d: "   << stride_d << " , batch_stride_a: "   << batch_stride_a << " , batch_stride_b: "   << batch_stride_b << " , batch_stride_d: "   << batch_stride_d << " , split_k: "   << split_k << " , batch: "   << batch << " , ne02: "   << ne02 << " , ne12: "   << ne12 << " , broadcast2: "   << broadcast2 << " , broadcast3: "   << broadcast3 << " )"  );
38873891    ggml_vk_sync_buffers (subctx);
38883892    if  (split_k == 1 ) {
3889-         const  vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
3893+         const  vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n  };
38903894        ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { a, b, d }, sizeof (vk_mat_mat_push_constants), &pc, { m, n, batch });
38913895        return ;
38923896    }
38933897
38943898    GGML_ASSERT (batch_stride_d == m * n);
38953899
3896-     const  vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV (k, split_k), ne02, ne12, broadcast2, broadcast3 };
3900+     const  vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV (k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n  };
38973901    //  Make sure enough workgroups get assigned for split k to work
38983902    ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof (vk_mat_mat_push_constants), &pc1, { (CEIL_DIV (m, pipeline->wg_denoms [0 ]) * pipeline->wg_denoms [0 ]) * split_k, n, batch });
38993903    ggml_vk_sync_buffers (subctx);
@@ -3937,14 +3941,15 @@ static void ggml_vk_matmul_id(
39373941        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
39383942        uint32_t  m, uint32_t  n, uint32_t  k, uint32_t  stride_a, uint32_t  stride_b, uint32_t  stride_d,
39393943        uint32_t  batch_stride_a, uint32_t  batch_stride_b, uint32_t  batch_stride_d,
3940-         uint32_t  n_as, uint32_t  nei0, uint32_t  nei1, uint32_t  nbi1, uint32_t  ne11) {
3944+         uint32_t  n_as, uint32_t  nei0, uint32_t  nei1, uint32_t  nbi1, uint32_t  ne11,
3945+         uint32_t  padded_n) {
39413946    VK_LOG_DEBUG (" ggml_vk_matmul_id(a: ("   << a.buffer ->buffer  << " , "   << a.offset  << " , "   << a.size  << " ), b: ("   << b.buffer ->buffer  << " , "   << b.offset  << " , "   << b.size  << " ), d: ("   << d.buffer ->buffer  << " , "   << d.offset  << " , "   << d.size  << " ), ids: ("   << ids.buffer ->buffer  << " , "   << ids.offset  << " , "   << ids.size  << " ), "   <<
39423947        " m: "   << m << " , n: "   << n << " , k: "   << k << " , stride_a: "   << stride_a << " , stride_b: "   << stride_b << " , stride_d: "   << stride_d << " , "   <<
39433948        " batch_stride_a: "   << batch_stride_a << " , batch_stride_b: "   << batch_stride_b << " , batch_stride_d: "   << batch_stride_d << " , "   <<
39443949        " n_as: "   << n_as << " , nei0: "   << nei0 << " , nei1: "   << nei1 << " , nbi1: "   << nbi1 << " , ne11: "   << ne11 << " )"  );
39453950    ggml_vk_sync_buffers (subctx);
39463951    const  vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
3947-                                               nei0, nei1, nbi1, ne11 };
3952+                                               nei0, nei1, nbi1, ne11, padded_n  };
39483953    ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { a, b, d, ids }, sizeof (vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
39493954}
39503955
@@ -4106,15 +4111,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
41064111    //  Not implemented
41074112    GGML_ASSERT (y_non_contig || !qy_needs_dequant);  //  NOLINT
41084113
4109-     const  int  x_ne = ne01 * ne00;
4110-     const  int  y_ne = ne11 * ne10;
4111-     const  int  d_ne = ne11 * ne01;
4112- 
41134114    const  uint32_t  kpad = ggml_vk_align_size (ne10, ggml_vk_guess_matmul_pipeline_align (ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type ));
41144115    const  bool  aligned = ne10 == kpad && ne01 > 8  && ne11 > 8 ;
41154116
41164117    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline (ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type );
41174118
4119+     //  Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4120+     uint32_t  padded_n = qy_needs_dequant ? ROUNDUP_POW2 (ne11, pipeline->wg_denoms [1 ]) :ne11;
4121+     const  int  x_ne = ne01 * ne00;
4122+     const  int  y_ne = padded_n * ne10;
4123+     const  int  d_ne = ne11 * ne01;
4124+ 
41184125    const  uint32_t  split_k = ggml_vk_guess_split_k (ctx, ne01, ne11, ne10, pipeline);
41194126
41204127    const  uint64_t  qx_sz = ggml_type_size (src0->type ) * x_ne / ggml_blck_size (src0->type );
@@ -4237,7 +4244,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
42374244        { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k , 0 , d_sz * ne12 * ne13 * split_k },
42384245        ne01, ne11, ne10,
42394246        ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
4240-         split_k, ne12*ne13, ne02, ne12, r2, r3
4247+         split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n 
42414248    );  //  NOLINT
42424249}
42434250
@@ -4688,15 +4695,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
46884695    //  Not implemented
46894696    GGML_ASSERT (y_non_contig || !qy_needs_dequant);  //  NOLINT
46904697
4691-     const  uint64_t  x_ne = ne01 * ne00;
4692-     const  uint64_t  y_ne = ne11 * ne10;
4693-     const  uint64_t  d_ne = ne21 * ne20;
4694- 
46954698    const  uint32_t  kpad = ggml_vk_align_size (ne10, ggml_vk_guess_matmul_id_pipeline_align (ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type ));
46964699    const  bool  aligned = ne10 == kpad && ne01 > 8  && nei1 > 8 ;
46974700
46984701    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline (ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type );
46994702
4703+     //  Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4704+     uint32_t  padded_n = qy_needs_dequant ? ROUNDUP_POW2 (ne11, pipeline->wg_denoms [1 ]) :ne11;
4705+     const  uint64_t  x_ne = ne01 * ne00;
4706+     const  uint64_t  y_ne = padded_n * ne10;
4707+     const  uint64_t  d_ne = ne21 * ne20;
4708+ 
47004709    const  uint64_t  qx_sz = ggml_type_size (src0->type ) * x_ne / ggml_blck_size (src0->type );
47014710    const  uint64_t  qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
47024711    const  uint64_t  x_sz = !qx_needs_dequant ? qx_sz : sizeof (ggml_fp16_t ) * x_ne;
@@ -4815,7 +4824,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
48154824        { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
48164825        ne01, ne21, ne10, ne10, ne10, ne01,
48174826        stride_batch_x, stride_batch_y, ne20*ne21,
4818-         n_as, nei0, nei1, nbi1 / ggml_type_size (ids->type ), ne11
4827+         n_as, nei0, nei1, nbi1 / ggml_type_size (ids->type ), ne11, padded_n 
48194828    );  //  NOLINT
48204829}
48214830
@@ -6775,7 +6784,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
67756784            ctx, subctx, p, ggml_vk_subbuffer (d_X), ggml_vk_subbuffer (d_Y), ggml_vk_subbuffer (d_D), ggml_vk_subbuffer (ctx->prealloc_split_k ),
67766785            m, n, k,
67776786            k, k, m, k*m, k*n, m*n,
6778-             split_k, batch, batch, batch, 1 , 1 
6787+             split_k, batch, batch, batch, 1 , 1 , n 
67796788        );
67806789    }
67816790    ggml_vk_ctx_end (subctx);
@@ -7120,7 +7129,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
71207129            ctx, subctx, p, ggml_vk_subbuffer (qx_buf), ggml_vk_subbuffer (y_buf), ggml_vk_subbuffer (d_buf), ggml_vk_subbuffer (ctx->prealloc_split_k ),
71217130            m, n, k,
71227131            k, k, m, k*m, k*n, m*n,
7123-             split_k, batch, batch, batch, 1 , 1 
7132+             split_k, batch, batch, batch, 1 , 1 , n 
71247133        );
71257134    }
71267135    ggml_vk_ctx_end (subctx);
0 commit comments