Skip to content

Commit ef7bda3

Browse files
committed
metal : mul_mm_id simplify + add test
1 parent cbc35ad commit ef7bda3

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 18 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7526,6 +7526,7 @@ kernel void kernel_mul_mm_id(
75267526
threadgroup char * shmem [[threadgroup(0)]],
75277527
uint3 tgpig[[threadgroup_position_in_grid]],
75287528
ushort tiitg[[thread_index_in_threadgroup]],
7529+
ushort tiisg[[thread_index_in_simdgroup]],
75297530
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
75307531

75317532
threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7631,36 +7632,36 @@ kernel void kernel_mul_mm_id(
76317632
}
76327633

76337634
threadgroup_barrier(mem_flags::mem_threadgroup);
7635+
76347636
threadgroup float * temp_str = ((threadgroup float *) shmem) \
76357637
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
7638+
76367639
for (short i = 0; i < 8; i++) {
76377640
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
76387641
}
76397642

76407643
threadgroup_barrier(mem_flags::mem_threadgroup);
76417644

7642-
if (sgitg == 0) {
7643-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7644-
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
7645+
for (int j = sgitg; j < n_cols; j += 4) {
7646+
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76457647

7646-
const int ide = id % args.ne20;
7647-
const int idt = id / args.ne20;
7648+
const int ide = id % args.ne20;
7649+
const int idt = id / args.ne20;
76487650

7649-
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
7650-
device float4 * D4 = (device float4 *) D;
7651+
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
7652+
device float4 * D4 = (device float4 *) D;
76517653

7652-
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7653-
threadgroup float4 * C4 = (threadgroup float4 *) C;
7654+
threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
7655+
threadgroup float4 * C4 = (threadgroup float4 *) C;
76547656

7655-
int i = 0;
7656-
for (; i < n_rows/4; i++) {
7657-
*(D4 + i) = *(C4 + i);
7658-
}
7657+
int i = tiisg;
7658+
for (; i < n_rows/4; i += 32) {
7659+
*(D4 + i) = *(C4 + i);
7660+
}
76597661

7660-
i *= 4;
7661-
for (; i < n_rows; i++) {
7662-
*(D + i) = *(C + i);
7663-
}
7662+
i = (4*(n_rows/4)) + tiisg;
7663+
for (; i < n_rows; i += 32) {
7664+
*(D + i) = *(C + i);
76647665
}
76657666
}
76667667
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5997,6 +5997,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
59975997
// test large experts*tokens
59985998
for (bool b : {false, true}) {
59995999
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
6000+
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
60006001
}
60016002

60026003
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));

0 commit comments

Comments
 (0)