Skip to content

Commit cbc35ad

Browse files
committed
metal : remove mul_mm_id hsrc1
1 parent 6b88434 commit cbc35ad

File tree

3 files changed

+34
-73
lines changed

3 files changed

+34
-73
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -324,36 +324,24 @@ typedef struct {
324324
int32_t ne11; // n_expert_used (bcast)
325325
uint64_t nb11;
326326
uint64_t nb12;
327-
int32_t neh11; // n_tokens
328-
uint64_t nbh11;
327+
int32_t ne21; // n_tokens
329328
int32_t ne20; // n_expert_used
330329
uint64_t nb21;
331330
} ggml_metal_kargs_mul_mm_id_map0;
332331

333-
typedef struct {
334-
int32_t ne20; // n_expert_used
335-
int32_t neh0;
336-
int32_t neh1;
337-
uint64_t nbh1;
338-
uint64_t nbh2;
339-
int32_t ne0;
340-
uint64_t nb1;
341-
uint64_t nb2;
342-
} ggml_metal_kargs_mul_mm_id_map1;
343-
344332
typedef struct {
345333
int32_t ne00;
346334
int32_t ne02;
347335
uint64_t nb01;
348336
uint64_t nb02;
349337
uint64_t nb03;
338+
int32_t ne11;
339+
uint64_t nb10;
340+
uint64_t nb11;
341+
uint64_t nb12;
342+
uint64_t nb13;
350343
int32_t ne20;
351344
int32_t ne21;
352-
int32_t neh12;
353-
uint64_t nbh10;
354-
uint64_t nbh11;
355-
uint64_t nbh12;
356-
uint64_t nbh13;
357345
int32_t ne0;
358346
int32_t ne1;
359347
int16_t r2;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 12 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -3876,22 +3876,6 @@ static int ggml_metal_encode_node(
38763876
default: break;
38773877
}
38783878

3879-
const int64_t neh10 = ne10; // n_embd
3880-
const int64_t neh11 = ne21; // n_tokens
3881-
const int64_t neh12 = ne02; // n_expert
3882-
3883-
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
3884-
const uint64_t nbh11 = nbh10*neh10;
3885-
const uint64_t nbh12 = nbh11*neh11;
3886-
const uint64_t nbh13 = nbh12*neh12;
3887-
3888-
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
3889-
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3890-
if (!h_src1) {
3891-
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3892-
return 0;
3893-
}
3894-
38953879
// tokens per expert
38963880
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
38973881
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3914,12 +3898,11 @@ static int ggml_metal_encode_node(
39143898

39153899
ggml_metal_kargs_mul_mm_id_map0 args = {
39163900
ne10,
3917-
ne11, // n_expert_used (bcast)
3901+
ne11, // n_expert_used (bcast)
39183902
nb11,
39193903
nb12,
3920-
neh11, // n_tokens
3921-
nbh11,
3922-
ne20, // n_expert_used
3904+
ne21, // n_tokens
3905+
ne20, // n_expert_used
39233906
nb21,
39243907
};
39253908

@@ -3929,11 +3912,9 @@ static int ggml_metal_encode_node(
39293912

39303913
[encoder setComputePipelineState:pipeline];
39313914
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3932-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3933-
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3934-
[encoder setBuffer: h_src1 offset:0 atIndex:3];
3935-
[encoder setBuffer: h_tpe offset:0 atIndex:4];
3936-
[encoder setBuffer: h_ids offset:0 atIndex:5];
3915+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
3916+
[encoder setBuffer: h_tpe offset:0 atIndex:2];
3917+
[encoder setBuffer: h_ids offset:0 atIndex:3];
39373918

39383919
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
39393920
}
@@ -3974,13 +3955,13 @@ static int ggml_metal_encode_node(
39743955
/*.nb01 =*/ nb01,
39753956
/*.nb02 =*/ nb02,
39763957
/*.nb03 =*/ nb03,
3958+
/*.ne11 =*/ ne11, // n_expert_used (bcast)
3959+
/*.nb10 =*/ nb10,
3960+
/*.nb11 =*/ nb11,
3961+
/*.nb12 =*/ nb12,
3962+
/*.nb13 =*/ nb13,
39773963
/*.ne20 =*/ ne20, // n_expert_used
39783964
/*.ne21 =*/ ne21, // n_tokens
3979-
/*.neh12 =*/ neh12,
3980-
/*.nbh10 =*/ nbh10,
3981-
/*.nbh11 =*/ nbh11,
3982-
/*.nbh12 =*/ nbh12,
3983-
/*.nbh13 =*/ nbh13,
39843965
/*.ne0 =*/ ne0,
39853966
/*.ne1 =*/ ne1,
39863967
/*.r2 =*/ r2,
@@ -3990,7 +3971,7 @@ static int ggml_metal_encode_node(
39903971
[encoder setComputePipelineState:pipeline];
39913972
[encoder setBytes:&args length:sizeof(args) atIndex:0];
39923973
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3993-
[encoder setBuffer: h_src1 offset:0 atIndex:2];
3974+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
39943975
[encoder setBuffer: h_tpe offset:0 atIndex:3];
39953976
[encoder setBuffer: h_ids offset:0 atIndex:4];
39963977
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 16 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -7477,9 +7477,7 @@ kernel void kernel_mul_mm(
74777477
template<typename T4>
74787478
kernel void kernel_mul_mm_id_map0(
74797479
constant ggml_metal_kargs_mul_mm_id_map0 & args,
7480-
device const char * src1,
74817480
device const char * src2,
7482-
device char * hsrc1,
74837481
device char * htpe,
74847482
device char * hids,
74857483
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -7491,24 +7489,16 @@ kernel void kernel_mul_mm_id_map0(
74917489

74927490
device int32_t * ids_i32 = (device int32_t *) (hids);
74937491

7494-
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
7492+
for (int i21 = 0; i21 < args.ne21; i21++) { // n_tokens
74957493
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
74967494

74977495
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
74987496
if (src2_i32[i20] != ide) {
74997497
continue;
75007498
}
75017499

7502-
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
7503-
device T4 * hsrc1_tx4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
7504-
7505-
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
7506-
hsrc1_tx4[i00] = (T4) (src1_f32x4[i00]);
7507-
}
7508-
75097500
if (tpitg.x == 0) {
7510-
//ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
7511-
ids_i32[ide*args.neh11 + n_all] = i21*args.ne20 + i20;
7501+
ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
75127502
}
75137503

75147504
++n_all;
@@ -7546,6 +7536,7 @@ kernel void kernel_mul_mm_id(
75467536
const int im = tgpig.z; // expert
75477537

75487538
device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
7539+
device const int32_t * ids_i32 = (device const int32_t *) (hids);
75497540

75507541
const int neh1 = tpe_i32[im];
75517542

@@ -7571,20 +7562,23 @@ kernel void kernel_mul_mm_id(
75717562

75727563
short il = (tiitg % THREAD_PER_ROW);
75737564

7574-
const int i12 = im%args.neh12;
7575-
const int i13 = im/args.neh12;
7565+
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
75767566

7577-
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7567+
const int i11 = (id % args.ne20) % args.ne11;
7568+
const int i12 = (id / args.ne20);
7569+
const int i13 = 0;
7570+
7571+
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
75787572
const short offset1 = il/nl;
75797573

75807574
device const block_q * x = (device const block_q *)(src0
75817575
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
75827576

7583-
device const half * y = (device const half *)(src1
7584-
+ args.nbh13*i13
7585-
+ args.nbh12*i12
7586-
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
7587-
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
7577+
device const float * y = (device const float *)(src1
7578+
+ args.nb13*i13
7579+
+ args.nb12*i12
7580+
+ args.nb11*i11
7581+
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
75887582

75897583
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
75907584
// load data and store to threadgroup memory
@@ -7600,7 +7594,7 @@ kernel void kernel_mul_mm_id(
76007594
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
76017595
}
76027596

7603-
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
7597+
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
76047598

76057599
il = (il + 2 < nl) ? il + 2 : il % 2;
76067600
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7646,13 +7640,11 @@ kernel void kernel_mul_mm_id(
76467640
threadgroup_barrier(mem_flags::mem_threadgroup);
76477641

76487642
if (sgitg == 0) {
7649-
device const int32_t * ids_i32 = (device const int32_t *) (hids);
7650-
76517643
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
76527644
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76537645

7654-
const int idt = id / args.ne20;
76557646
const int ide = id % args.ne20;
7647+
const int idt = id / args.ne20;
76567648

76577649
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
76587650
device float4 * D4 = (device float4 *) D;

0 commit comments

Comments
 (0)