@@ -2106,12 +2106,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms = { 32,  64,  1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
-        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
-        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
-        l_mmq_wg_denoms_k = { 64, 128, 1 };
-        m_mmq_wg_denoms_k = { 32, 64, 1 };
-        s_mmq_wg_denoms_k = { 32, 32, 1 };
+        l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
+        l_mmq_wg_denoms_k = { 128, 256, 1 };
+        m_mmq_wg_denoms_k = { 128, 128, 1 };
+        s_mmq_wg_denoms_k = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul_id
         l_warptile_mmqid = { 256, 128, 128, 16, 0 };
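
For a sense of scale (an illustration, not part of the commit): the wg_denoms triples give the output tile each workgroup covers, so the larger k-quant tiles cut the workgroup count for a given output size. A standalone sketch of that arithmetic, with a hypothetical 4096x4096 output and a local ceil_div standing in for CEIL_DIV:

    // Workgroups needed to tile an M x N output; the matrix size is illustrative only.
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t m = 4096, n = 4096;
        printf("old large k-quant tile  64x128: %u workgroups\n", ceil_div(m,  64) * ceil_div(n, 128)); // 2048
        printf("new large k-quant tile 128x256: %u workgroups\n", ceil_div(m, 128) * ceil_div(n, 256)); // 512
        return 0;
    }
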
@@ -5022,26 +5022,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     ggml_vk_queue_command_pools_cleanup(dst->device);
 }
 
-static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
 
     uint32_t split_k = 1;
-    if (ctx->device->shader_core_count != 0 && m >= (int) pipeline->wg_denoms[0] && n >= (int) pipeline->wg_denoms[1]) {
+    if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
         // If k is 'large' and the SMs will fill less than halfway, use split_k.
         uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
         uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
-        if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
-            split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
-            // Clamp to 2 or 4
-            split_k = std::min(split_k, 4u);
-            if (split_k == 3) {
-                split_k = 2;
+
+        if (k >= 2048) {
+            if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
+                split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
+                split_k = 3;
             }
-            if (ctx->device->coopmat2) {
-                // coopmat2 shader expects splits to be aligned to 256
-                while (split_k > 1 && ((k / split_k) % 256) != 0) {
-                    split_k /= 2;
+            // Cap the split at 8x. Unless k is huge this is a lot of overhead.
+            split_k = std::min(split_k, 8u);
+
+            // ggml_vk_matmul will align the splits to be a multiple of 256.
+            // If this rounded up size would cause the last split to be empty,
+            // then reduce the split count.
+            while (true) {
+                if (split_k == 1) {
+                    break;
+                }
+                uint32_t k_split = CEIL_DIV(k, split_k);
+                k_split = ROUNDUP_POW2(k_split, 256);
+                if (k_split * (split_k - 1) < k) {
+                    break;
                 }
+                split_k--;
             }
         }
     }
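
To see what the reworked heuristic produces (an illustration, not part of the commit): a standalone sketch assuming a hypothetical 80-core GPU, with CEIL_DIV and ROUNDUP_POW2 re-implemented locally and the tile grid reduced to precomputed m_tiles/n_tiles:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b)     { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); } // n must be a power of two

    static uint32_t guess_split_k(uint32_t m_tiles, uint32_t n_tiles, uint32_t k, uint32_t cores) {
        uint32_t split_k = 1;
        if (k >= 2048) {
            if (m_tiles * n_tiles <= cores / 2) {
                split_k = cores / (m_tiles * n_tiles);
            } else if (m_tiles * n_tiles <= cores * 2 / 3) {
                split_k = 3;
            }
            split_k = std::min(split_k, 8u);
            // drop back while the 256-aligned split size would leave the last split empty
            while (split_k > 1) {
                const uint32_t k_split = roundup_pow2(ceil_div(k, split_k), 256);
                if (k_split * (split_k - 1) < k) {
                    break;
                }
                split_k--;
            }
        }
        return split_k;
    }

    int main() {
        // 2 output tiles on an 80-core GPU: 80/2 = 40, capped to 8
        printf("k=4096: split_k=%u\n", guess_split_k(2, 1, 4096, 80)); // 8
        // same shape but k=2304: eight 512-wide splits would leave the last one empty, shrinks to 5
        printf("k=2304: split_k=%u\n", guess_split_k(2, 1, 2304, 80)); // 5
        return 0;
    }

The while(true)/break form in the patch and the loop-condition form here compute the same result; the sketch only restructures it for brevity.
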
@@ -5053,9 +5064,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
 
     if (ctx->device->coopmat2) {
+        const uint32_t shader_core_count = ctx->device->shader_core_count;
+        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
+        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
+
         // Use large shader when the N dimension is greater than the medium shader's tile size
         uint32_t crossover_large = mmp->m->wg_denoms[1];
-        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+
+        // Prefer large over medium if either:
+        // - medium or large tiles would overfill the GPU
+        // - large tiles with split_k==3 fit in the GPU and medium tiles with split_k==2 do not
+        //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
+        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
+                            // split_k==3 with large tiles likely better than medium tiles with no split_k.
+                            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
+
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
         // Use medium shader when the N dimension is greater than the small shader's tile size
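
The same tile counts drive the large-vs-medium choice. A standalone sketch of the prefer_large decision (an illustration, not part of the commit), plugging in the k-quant tile sizes from the first hunk (large 128x256, medium 128x128) and the same hypothetical 80-core GPU:

    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    // tiles_l/tiles_m mirror the CEIL_DIV products computed from a_l/a_m wg_denoms above
    static bool prefer_large(uint32_t m, uint32_t n, uint32_t cores) {
        const uint32_t tiles_l = ceil_div(m, 128) * ceil_div(n, 256);
        const uint32_t tiles_m = ceil_div(m, 128) * ceil_div(n, 128);
        return tiles_m > cores || tiles_l > cores ||
               (tiles_l <= cores / 3 && tiles_m > cores / 2);
    }

    int main() {
        printf("2048x2048: %d\n", prefer_large(2048, 2048, 80)); // 1: both tile grids overfill the GPU
        printf(" 512x 512: %d\n", prefer_large( 512,  512, 80)); // 0: 16 medium tiles fit easily, keep medium
        printf("1536x 512: %d\n", prefer_large(1536,  512, 80)); // 1: 24 large tiles leave room for split_k==3, 48 medium tiles rule out split_k==2
        return 0;
    }
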
@@ -5099,7 +5123,11 @@ static void ggml_vk_matmul(
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
+    // Round the split size up to a multiple of 256 (k-quant alignment)
+    uint32_t k_split = CEIL_DIV(k, split_k);
+    k_split = ROUNDUP_POW2(k_split, 256);
+
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
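
One way to read the new k_split push constant (an illustration, not part of the commit): each of the split_k passes covers a contiguous 256-aligned slice of K, and the guess_split_k loop above guarantees the final slice is smaller but never empty. A small sketch of that partitioning, reusing the values from the earlier example:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b)     { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

    int main() {
        const uint32_t k = 2304, split_k = 5;                             // values from the earlier sketch
        const uint32_t k_split = roundup_pow2(ceil_div(k, split_k), 256); // 512
        for (uint32_t i = 0; i < split_k; i++) {
            const uint32_t begin = i * k_split;
            const uint32_t end   = std::min(begin + k_split, k);
            printf("split %u: K range [%u, %u), %u elements\n", i, begin, end, end - begin);
        }
        return 0;
    }

The dispatch line above multiplies the x workgroup count by split_k, so each of these slices gets its own full set of tile workgroups before the results are reduced from split_k_buffer.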