Skip to content

Commit d2a2673

Browse files
0cc4m and ggerganov authored
vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader
* metal : fix mul_mm_id

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 13002a0 commit d2a2673

File tree

4 files changed

+9
-2
lines changed

4 files changed

+9
-2
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
677677
char name[256];
678678

679679
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
680-
snprintf(name, 256, "%s", base);
680+
snprintf(name, 256, "%s_ne02=%d", base, ne02);
681681

682682
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
683683
if (res) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
8282

8383
#include "mul_mmq_shmem_types.glsl"
8484

85+
#ifdef MUL_MAT_ID
86+
#define BK_STEP 1
87+
#else
8588
#ifndef BK_STEP
8689
#define BK_STEP 4
8790
#endif
91+
#endif
8892

8993
// Shared memory cache
9094
shared block_a_cache buf_a[BM * BK_STEP];

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ struct block_a_cache {
2727
#elif defined(DATA_A_Q8_0)
2828
#define QUANT_R_MMQ 1
2929
// AMD likes 4, Intel likes 1 and Nvidia likes 2
30-
#define BK_STEP 1
30+
// #define BK_STEP 1
3131
struct block_a_cache {
3232
int32_t qs[32/4];
3333
FLOAT_TYPE dm;

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6880,6 +6880,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
68806880
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
68816881
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
68826882

6883+
// gpt-oss issue with Vulkan mmq_id
6884+
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
6885+
68836886
for (ggml_type type_a : base_types) {
68846887
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
68856888
for (int n_mats : {4, 8}) {

0 commit comments

Comments
 (0)