Skip to content

Commit 34981fa

Browse files
committed
metal : optimize mul_mm_id id gathering
1 parent b534e39 commit 34981fa

File tree

3 files changed: 48 additions (+48) and 24 deletions (-24).

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -320,6 +320,7 @@ typedef struct {
320320
} ggml_metal_kargs_mul_mv_ext;
321321

322322
typedef struct {
323+
int32_t ne02;
323324
int32_t ne10;
324325
int32_t ne11; // n_expert_used (bcast)
325326
uint64_t nb11;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3895,6 +3895,7 @@ static int ggml_metal_encode_node(
38953895

38963896
{
38973897
ggml_metal_kargs_mul_mm_id_map0 args = {
3898+
ne02,
38983899
ne10,
38993900
ne11, // n_expert_used (bcast)
39003901
nb11,
@@ -3908,13 +3909,20 @@ static int ggml_metal_encode_node(
39083909

39093910
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
39103911

3912+
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
3913+
3914+
const size_t smem = ne02*ne20*sizeof(uint16_t);
3915+
3916+
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
3917+
39113918
[encoder setComputePipelineState:pipeline];
39123919
[encoder setBytes:&args length:sizeof(args) atIndex:0];
39133920
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
39143921
[encoder setBuffer: h_tpe offset:0 atIndex:2];
39153922
[encoder setBuffer: h_ids offset:0 atIndex:3];
3923+
[encoder setThreadgroupMemoryLength:ne02*ne20*sizeof(uint16_t) atIndex:0];
39163924

3917-
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3925+
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
39183926
}
39193927

39203928
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 38 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -7480,31 +7480,45 @@ kernel void kernel_mul_mm_id_map0(
74807480
device const char * src2,
74817481
device char * htpe,
74827482
device char * hids,
7483-
uint3 tgpig[[threadgroup_position_in_grid]],
7484-
ushort3 tpitg[[thread_position_in_threadgroup]],
7485-
ushort3 ntg[[threads_per_threadgroup]]) {
7486-
const int ide = tgpig[0]; // expert id
7483+
threadgroup char * shmem [[threadgroup(0)]],
7484+
ushort tpitg[[thread_position_in_threadgroup]],
7485+
ushort ntg[[threads_per_threadgroup]]) {
7486+
const short ide = tpitg; // expert id
74877487

7488-
int n_all = 0;
7488+
uint32_t n_all = 0;
74897489

74907490
device int32_t * ids_i32 = (device int32_t *) (hids);
74917491

7492-
for (int i21 = 0; i21 < args.ne21; i21++) { // n_tokens
7493-
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
7492+
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
7493+
{
7494+
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
74947495

7495-
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
7496-
if (src2_i32[i20] != ide) {
7497-
continue;
7496+
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*args.ne20;
7497+
7498+
for (int i20 = 0; i20 < args.ne20 && i21 + tpitg < args.ne21; i20++) {
7499+
sids[i20] = src2_i32[i20];
74987500
}
7501+
}
7502+
7503+
threadgroup_barrier(mem_flags::mem_threadgroup);
74997504

7500-
ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
7505+
for (int t = 0; t < ntg && i21 + t < args.ne21; t++) {
7506+
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + t*args.ne20;
75017507

7502-
++n_all;
7508+
for (int i20 = 0; i20 < args.ne20; i20++) {
7509+
if (sids[i20] == ide) {
7510+
ids_i32[ide*args.ne21 + n_all] = (i21 + t)*args.ne20 + i20;
7511+
++n_all;
7512+
break;
7513+
}
7514+
}
75037515
}
7516+
7517+
threadgroup_barrier(mem_flags::mem_threadgroup);
75047518
}
75057519

7506-
device int32_t * tpe_i32 = (device int32_t *) (htpe);
7507-
tpe_i32[ide] = n_all;
7520+
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
7521+
tpe_u32[ide] = n_all;
75087522
}
75097523

75107524
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
@@ -7532,10 +7546,10 @@ kernel void kernel_mul_mm_id(
75327546
const int r1 = tgpig.x;
75337547
const int im = tgpig.z; // expert
75347548

7535-
device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
7536-
device const int32_t * ids_i32 = (device const int32_t *) (hids);
7549+
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
7550+
device const int32_t * ids_i32 = (device const int32_t *) (hids);
75377551

7538-
const int neh1 = tpe_i32[im];
7552+
const uint32_t neh1 = tpe_u32[im];
75397553

75407554
if (r1*BLOCK_SIZE_N >= neh1) {
75417555
return;
@@ -7561,9 +7575,9 @@ kernel void kernel_mul_mm_id(
75617575

75627576
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
75637577

7564-
const int i11 = (id % args.ne20) % args.ne11;
7565-
const int i12 = (id / args.ne20);
7566-
const int i13 = 0;
7578+
const short i11 = (id % args.ne20) % args.ne11;
7579+
const short i12 = (id / args.ne20);
7580+
const short i13 = 0;
75677581

75687582
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
75697583
const short offset1 = il/nl;
@@ -7632,17 +7646,18 @@ kernel void kernel_mul_mm_id(
76327646
threadgroup float * temp_str = ((threadgroup float *) shmem) \
76337647
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
76347648

7649+
#pragma unroll(8)
76357650
for (short i = 0; i < 8; i++) {
76367651
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
76377652
}
76387653

76397654
threadgroup_barrier(mem_flags::mem_threadgroup);
76407655

7641-
for (int j = sgitg; j < n_cols; j += 4) {
7656+
for (short j = sgitg; j < n_cols; j += 4) {
76427657
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76437658

7644-
const int ide = id % args.ne20;
7645-
const int idt = id / args.ne20;
7659+
const short ide = id % args.ne20;
7660+
const short idt = id / args.ne20;
76467661

76477662
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
76487663
device float4 * D4 = (device float4 *) D;

0 commit comments

Comments
 (0)