Skip to content

Commit c9d7147

Browse files
committed
vulkan: fix shmem overrun in mmq id shader
1 parent 16724b5 commit c9d7147

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
8282

8383
#include "mul_mmq_shmem_types.glsl"
8484

85+
#ifdef MUL_MAT_ID
86+
#define BK_STEP 1
87+
#else
8588
#ifndef BK_STEP
8689
#define BK_STEP 4
8790
#endif
91+
#endif
8892

8993
// Shared memory cache
9094
shared block_a_cache buf_a[BM * BK_STEP];

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct block_a_cache {
2727
#elif defined(DATA_A_Q8_0)
2828
#define QUANT_R_MMQ 1
2929
// AMD likes 4, Intel likes 1 and Nvidia likes 2
30-
#define BK_STEP 1
30+
// #define BK_STEP 1
3131
struct block_a_cache {
3232
int32_t qs[32/4];
3333
FLOAT_TYPE dm;

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6880,6 +6880,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
68806880
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
68816881
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
68826882

6883+
// gpt-oss issue with Vulkan mmq_id
6884+
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
6885+
68836886
for (ggml_type type_a : base_types) {
68846887
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
68856888
for (int n_mats : {4, 8}) {

0 commit comments

Comments
 (0)