Commit 05c4cae

vulkan: coopmat2 mul_mat optimizations
- Increase tile size for k-quants, to match non-k-quants
- Choose more carefully between large and medium tiles, considering how it interacts with split_k
- Allow larger/non-power-of-two split_k, and make the splits a multiple of 256
- Use split_k==3 when >1/2 and <=2/3 of the SMs would have been used
1 parent 8ad7b3e commit 05c4cae


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 48 additions & 20 deletions
@@ -2068,12 +2068,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
-        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
-        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
-        l_mmq_wg_denoms_k = { 64, 128, 1 };
-        m_mmq_wg_denoms_k = { 32, 64, 1 };
-        s_mmq_wg_denoms_k = { 32, 32, 1 };
+        l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
+        l_mmq_wg_denoms_k = { 128, 256, 1 };
+        m_mmq_wg_denoms_k = { 128, 128, 1 };
+        s_mmq_wg_denoms_k = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul_id
         l_warptile_mmqid = { 256, 128, 128, 16, 0 };
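
For scale (an illustrative aside, not part of the diff): the first two entries of each *_mmq_wg_denoms_k triple divide the M and N dimensions when workgroups are counted, as in the CEIL_DIV(m, pipeline->wg_denoms[0]) calls further down, so growing the large k-quant tile from 64x128 to 128x256 cuts the workgroup count for a given matrix by 4x. A minimal sketch, with a hypothetical 4096x4096 output and CEIL_DIV written to match the macro in ggml-vulkan.cpp:

#include <cstdint>
#include <cstdio>

// Mirrors the CEIL_DIV macro used by ggml-vulkan.cpp.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main() {
    const uint32_t m = 4096, n = 4096; // hypothetical output shape
    // Old vs. new large k-quant tile (l_mmq_wg_denoms_k), M x N per workgroup.
    printf("old 64x128 tile:  %u workgroups\n", CEIL_DIV(m,  64u) * CEIL_DIV(n, 128u)); // 2048
    printf("new 128x256 tile: %u workgroups\n", CEIL_DIV(m, 128u) * CEIL_DIV(n, 256u)); //  512
    return 0;
}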
@@ -4943,26 +4943,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     ggml_vk_queue_command_pools_cleanup(dst->device);
 }
 
-static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
 
     uint32_t split_k = 1;
-    if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
+    if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
         // If k is 'large' and the SMs will fill less than halfway, use split_k.
         uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
         uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
-        if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
-            split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
-            // Clamp to 2 or 4
-            split_k = std::min(split_k, 4u);
-            if (split_k == 3) {
-                split_k = 2;
+
+        if (k >= 2048) {
+            if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
+                split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
+                split_k = 3;
             }
-            if (ctx->device->coopmat2) {
-                // coopmat2 shader expects splits to be aligned to 256
-                while (split_k > 1 && ((k / split_k) % 256) != 0) {
-                    split_k /= 2;
+            // Cap the split at 8x. Unless k is huge this is a lot of overhead.
+            split_k = std::min(split_k, 8u);
+
+            // ggml_vk_matmul will align the splits to be a multiple of 256.
+            // If this rounded up size would cause the last split to be empty,
+            // then reduce the split count.
+            while (true) {
+                if (split_k == 1) {
+                    break;
+                }
+                uint32_t k_split = CEIL_DIV(k, split_k);
+                k_split = ROUNDUP_POW2(k_split, 256);
+                if (k_split * (split_k - 1) < k) {
+                    break;
                 }
+                split_k--;
             }
         }
     }
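
To make the new selection concrete, here is a self-contained sketch of the same logic with the device state reduced to plain parameters (the 40-SM count and the inputs in main are assumptions for illustration; CEIL_DIV and ROUNDUP_POW2 mirror the macros in this file):

#include <algorithm>
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b)     (((a) + (b) - 1) / (b))
#define ROUNDUP_POW2(a, b) (((a) + (b) - 1) & ~((b) - 1))

// Sketch of the commit's split_k choice for one workgroup-tile count and k,
// with shader_core_count passed in rather than read from the vk_device.
static uint32_t guess_split_k(uint32_t tiles, uint32_t k, uint32_t shader_core_count) {
    uint32_t split_k = 1;
    if (k >= 2048) {
        if (tiles <= shader_core_count / 2) {
            split_k = shader_core_count / tiles;
        } else if (tiles <= shader_core_count * 2 / 3) {
            split_k = 3; // the new >1/2 and <=2/3 occupancy window
        }
        split_k = std::min(split_k, 8u); // cap the split at 8x

        // Shrink the split while a 256-aligned split size would leave the
        // last split empty.
        while (split_k > 1) {
            uint32_t k_split = ROUNDUP_POW2(CEIL_DIV(k, split_k), 256u);
            if (k_split * (split_k - 1) < k) {
                break;
            }
            split_k--;
        }
    }
    return split_k;
}

int main() {
    const uint32_t sm = 40; // hypothetical shader_core_count
    printf("%u\n", guess_split_k(16, 8192,  sm)); // 16 <= 40/2    -> 40/16 = 2
    printf("%u\n", guess_split_k(24, 8192,  sm)); // 20 < 24 <= 26 -> 3 (new case)
    printf("%u\n", guess_split_k( 4, 65536, sm)); // 40/4 = 10     -> capped to 8
    printf("%u\n", guess_split_k( 5, 2304,  sm)); // 8 shrinks to 5: 512 * 4 = 2048 < 2304
    return 0;
}

The last case shows the new reduction loop at work: the old code could only halve split_k to keep splits aligned (and forced 3 down to 2), while the new one decrements it, so non-power-of-two splits like 3 and 5 survive.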
@@ -4974,9 +4985,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
 
     if (ctx->device->coopmat2) {
+        const uint32_t shader_core_count = ctx->device->shader_core_count;
+        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
+        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
+
         // Use large shader when the N dimension is greater than the medium shader's tile size
         uint32_t crossover_large = mmp->m->wg_denoms[1];
-        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+
+        // Prefer large over medium if either:
+        // - medium or large tiles would overfill the GPU
+        // - large tiles with split_k==3 fit in the GPU and medium tiles with split_k==2 do not
+        //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
+        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
+            // split_k==3 with large tiles likely better than medium tiles with no split_k.
+            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
+
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
         // Use medium shader when the N dimension is greater than the small shader's tile size
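
A worked example of prefer_large, again with an assumed 40-SM device and the new k-quant tile shapes (only the inputs are hypothetical; the expression matches the diff):

#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main() {
    const uint32_t sm = 40;           // assumed shader_core_count
    const uint32_t m = 256, n = 1536; // example GEMM shape

    // New k-quant tiles: large is 128x256, medium is 128x128.
    const uint32_t tiles_l = CEIL_DIV(m, 128u) * CEIL_DIV(n, 256u); // 2 * 6  = 12
    const uint32_t tiles_m = CEIL_DIV(m, 128u) * CEIL_DIV(n, 128u); // 2 * 12 = 24

    // Neither tile count overfills 40 SMs, but large tiles with split_k==3
    // (12 * 3 = 36 <= 40) fit while medium with split_k==2 (24 * 2 = 48) would not.
    const bool prefer_large = tiles_m > sm || tiles_l > sm ||
                              (tiles_l <= sm / 3 && tiles_m > sm / 2);
    printf("tiles_l=%u tiles_m=%u prefer_large=%d\n", tiles_l, tiles_m, prefer_large); // 12 24 1
    return 0;
}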
@@ -5020,7 +5044,11 @@ static void ggml_vk_matmul(
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
+    // Round the split size up to a multiple of 256 (k-quant alignment)
+    uint32_t k_split = CEIL_DIV(k, split_k);
+    k_split = ROUNDUP_POW2(k_split, 256);
+
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
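
For the rounding itself, a small worked example with assumed sizes: k = 4096 and split_k = 3 give CEIL_DIV(k, split_k) = 1366, which ROUNDUP_POW2 raises to 1536, so the three splits cover 1536 + 1536 + 1024 elements. The 256 alignment keeps every split on a k-quant superblock (QK_K = 256) boundary, and ggml_vk_guess_split_k has already guaranteed the last split is non-empty:

#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b)     (((a) + (b) - 1) / (b))
#define ROUNDUP_POW2(a, b) (((a) + (b) - 1) & ~((b) - 1))

int main() {
    const uint32_t k = 4096, split_k = 3;    // example sizes
    uint32_t k_split = CEIL_DIV(k, split_k); // 1366
    k_split = ROUNDUP_POW2(k_split, 256u);   // 1536, a multiple of 256

    for (uint32_t s = 0; s < split_k; ++s) {
        const uint32_t begin = s * k_split;
        const uint32_t end = begin + k_split < k ? begin + k_split : k;
        printf("split %u covers k in [%u, %u)\n", s, begin, end); // [0,1536) [1536,3072) [3072,4096)
    }
    return 0;
}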
