Skip to content

Commit 6e219b7

Browse files
committed
metal : mul_mm support ne00 % 32 != 0
1 parent 6f13728 commit 6e219b7

File tree

6 files changed

+74
-23
lines changed

6 files changed

+74
-23
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,19 +438,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
438438
return res;
439439
}
440440

441-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
441+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
442442
char base[256];
443443
char name[256];
444444

445+
const ggml_type tsrc0 = op->src[0]->type;
446+
const ggml_type tsrc1 = op->src[1]->type;
447+
448+
const bool bc = op->src[0]->ne[0] % 32 != 0;
449+
445450
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
446-
snprintf(name, 256, "%s", base);
451+
snprintf(name, 256, "%s_bc=%d", base, bc);
447452

448453
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
449454
if (res) {
450455
return res;
451456
}
452457

453-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
458+
ggml_metal_cv_t cv = ggml_metal_cv_init();
459+
460+
ggml_metal_cv_set_bool(cv, bc, FC_MUL_MM + 0);
461+
462+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
463+
464+
ggml_metal_cv_free(cv);
454465

455466
ggml_metal_pipeline_set_smem(res, 8192);
456467

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me
115115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
116116
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
117117
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
118-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
118+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
119119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
121121
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
#define FC_FLASH_ATTN_EXT_VEC 200
7777
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
7878
#define FC_MUL_MV 400
79+
#define FC_MUL_MM 500
7980

8081
// kernel argument structs
8182
//

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,21 +1476,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
14761476
!ggml_is_transposed(op->src[1]) &&
14771477
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
14781478
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1479-
props_dev->has_simdgroup_mm &&
1480-
ne00 % 32 == 0 && ne00 >= 64 &&
1479+
props_dev->has_simdgroup_mm && ne00 >= 64 &&
14811480
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
14821481
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
14831482

14841483
// some Metal matrix data types require aligned pointers
14851484
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1486-
switch (op->src[0]->type) {
1487-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1488-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1489-
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1490-
default: break;
1491-
}
1485+
//switch (op->src[0]->type) {
1486+
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1487+
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1488+
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1489+
// default: break;
1490+
//}
14921491

1493-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
1492+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
14941493

14951494
ggml_metal_kargs_mul_mm args = {
14961495
/*.ne00 =*/ ne00,

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7856,6 +7856,8 @@ kernel void kernel_set_rows_f(
78567856
}
78577857
}
78587858

7859+
constant bool FC_mul_mm_bounds_check [[function_constant(FC_MUL_MM + 0)]];
7860+
78597861
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
78607862
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
78617863
#define BLOCK_SIZE_K 32
@@ -7913,27 +7915,58 @@ kernel void kernel_mul_mm(
79137915
device const block_q * x = (device const block_q *)(src0
79147916
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
79157917

7918+
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
7919+
79167920
device const U * y = (device const U *)(src1
79177921
+ args.nb13*i13
79187922
+ args.nb12*i12
79197923
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
7920-
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
7924+
+ args.nb10*iy);
79217925

79227926
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
79237927
// load data and store to threadgroup memory
7924-
T4x4 temp_a;
7925-
dequantize_func(x, il, temp_a);
7928+
if (is_same<T4x4, block_q>::value) {
7929+
// no need for dequantization
7930+
threadgroup_barrier(mem_flags::mem_threadgroup);
79267931

7927-
threadgroup_barrier(mem_flags::mem_threadgroup);
7932+
if (FC_mul_mm_bounds_check) {
7933+
// bounds checks are required
7934+
#pragma unroll(16)
7935+
for (short i = 0; i < 16; i++) {
7936+
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7937+
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7938+
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T *) x)[16*il + i] : 0;
7939+
}
7940+
} else {
7941+
// do not perform bounds checks
7942+
#pragma unroll(16)
7943+
for (short i = 0; i < 16; i++) {
7944+
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7945+
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7946+
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = ((device T *) x)[i];
7947+
}
7948+
}
7949+
} else {
7950+
T4x4 temp_a;
7951+
dequantize_func(x, il, temp_a);
79287952

7929-
#pragma unroll(16)
7930-
for (short i = 0; i < 16; i++) {
7931-
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7932-
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7933-
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
7953+
threadgroup_barrier(mem_flags::mem_threadgroup);
7954+
7955+
#pragma unroll(16)
7956+
for (short i = 0; i < 16; i++) {
7957+
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7958+
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7959+
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
7960+
}
79347961
}
79357962

7936-
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device U2x4 *) y));
7963+
if (FC_mul_mm_bounds_check) {
7964+
for (short i = 0; i < 8; ++i) {
7965+
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? ((device U *) y)[i] : 0;
7966+
}
7967+
} else {
7968+
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device U2x4 *) y));
7969+
}
79377970

79387971
il = (il + 2 < nl) ? il + 2 : il % 2;
79397972
x = (il < 2) ? x + (2 + nl - 1)/nl : x;

tests/test-backend-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6293,6 +6293,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
62936293
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
62946294
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
62956295

6296+
#if 0
6297+
// test the mat-mat path for Metal
6298+
for (int k = 1; k < 512; ++k) {
6299+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));
6300+
}
6301+
#endif
6302+
62966303
for (auto bs2 : {1,3}) {
62976304
for (auto bs : {1,2,4,8}) {
62986305
for (auto nr : {1,4}) {

0 commit comments

Comments
 (0)