Skip to content

Commit 1d8d83d

Browse files
authored
metal : improve MUL_MAT_ID (ggml-org#15541)
* metal : mul_mm_id remove hdst
* metal : remove mul_mm_id hsrc1
* metal : mul_mm_id simplify + add test
* metal : opt mul_mm_id map0
* metal : optimize mul_mm_id id gathering
* metal : mul/div opt
* metal : optimize mul_mm_id_map0

ggml-ci
1 parent c4e9239 commit 1d8d83d

File tree

4 files changed

+170
-218
lines changed

4 files changed

+170
-218
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -320,40 +320,31 @@ typedef struct {
320320
} ggml_metal_kargs_mul_mv_ext;
321321

322322
typedef struct {
323+
int32_t ne02;
323324
int32_t ne10;
324325
int32_t ne11; // n_expert_used (bcast)
325326
uint64_t nb11;
326327
uint64_t nb12;
327-
int32_t neh11; // n_tokens
328-
uint64_t nbh11;
328+
int32_t ne21; // n_tokens
329329
int32_t ne20; // n_expert_used
330330
uint64_t nb21;
331331
} ggml_metal_kargs_mul_mm_id_map0;
332332

333-
typedef struct {
334-
int32_t ne20; // n_expert_used
335-
int32_t neh0;
336-
int32_t neh1;
337-
uint64_t nbh1;
338-
uint64_t nbh2;
339-
int32_t ne0;
340-
uint64_t nb1;
341-
uint64_t nb2;
342-
} ggml_metal_kargs_mul_mm_id_map1;
343-
344333
typedef struct {
345334
int32_t ne00;
346335
int32_t ne02;
347336
uint64_t nb01;
348337
uint64_t nb02;
349338
uint64_t nb03;
350-
int32_t neh12;
351-
uint64_t nbh10;
352-
uint64_t nbh11;
353-
uint64_t nbh12;
354-
uint64_t nbh13;
355-
int32_t neh0;
356-
int32_t neh1;
339+
int32_t ne11;
340+
uint64_t nb10;
341+
uint64_t nb11;
342+
uint64_t nb12;
343+
uint64_t nb13;
344+
int32_t ne20;
345+
int32_t ne21;
346+
int32_t ne0;
347+
int32_t ne1;
357348
int16_t r2;
358349
int16_t r3;
359350
} ggml_metal_kargs_mul_mm_id;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 52 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
398398
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
399399
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
400400
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
401-
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
402-
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
401+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
402+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
403+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
404+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
405+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
406+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
403407
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
404408
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
405409
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1428,8 +1432,12 @@ @implementation GGMLMetalClass
14281432
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14291433
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14301434
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1431-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1432-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
1435+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1436+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1437+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1438+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1439+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1440+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14331441
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14341442
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14351443
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node(
39083916
default: break;
39093917
}
39103918

3911-
const int64_t neh10 = ne10; // n_embd
3912-
const int64_t neh11 = ne21; // n_tokens
3913-
const int64_t neh12 = ne02; // n_expert
3914-
3915-
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
3916-
const uint64_t nbh11 = nbh10*neh10;
3917-
const uint64_t nbh12 = nbh11*neh11;
3918-
const uint64_t nbh13 = nbh12*neh12;
3919-
3920-
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
3921-
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3922-
if (!h_src1) {
3923-
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3924-
return 0;
3925-
}
3926-
3927-
const int64_t neh0 = ne0;
3928-
const int64_t neh1 = ne21;
3929-
const int64_t neh2 = ne02;
3930-
3931-
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
3932-
const uint64_t nbh1 = nbh0*neh0;
3933-
const uint64_t nbh2 = nbh1*neh1;
3934-
//const uint64_t nbh3 = nbh2*neh2;
3935-
3936-
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
3937-
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3938-
if (!h_dst) {
3939-
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3940-
return 0;
3941-
}
3942-
39433919
// tokens per expert
39443920
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
39453921
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3949,41 +3925,54 @@ static int ggml_metal_encode_node(
39493925
}
39503926

39513927
// id map
3952-
// [n_expert_used, n_tokens]
3953-
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
3928+
// [n_tokens, n_expert]
3929+
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
39543930
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
39553931
if (!h_ids) {
39563932
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
39573933
return 0;
39583934
}
39593935

39603936
{
3961-
const int nth = MIN(1024, ne10/4);
3962-
39633937
ggml_metal_kargs_mul_mm_id_map0 args = {
3938+
ne02,
39643939
ne10,
3965-
ne11, // n_expert_used (bcast)
3940+
ne11, // n_expert_used (bcast)
39663941
nb11,
39673942
nb12,
3968-
neh11, // n_tokens
3969-
nbh11,
3970-
ne20, // n_expert_used
3943+
ne21, // n_tokens
3944+
ne20, // n_expert_used
39713945
nb21,
39723946
};
39733947

39743948
id<MTLComputePipelineState> pipeline = nil;
39753949

3976-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
3950+
pipeline = nil;
3951+
3952+
switch (ne20) {
3953+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
3954+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
3955+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
3956+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
3957+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
3958+
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
3959+
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
3960+
}
3961+
3962+
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
3963+
3964+
const size_t smem = ne02*ne20*sizeof(uint16_t);
3965+
3966+
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
39773967

39783968
[encoder setComputePipelineState:pipeline];
39793969
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3980-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3981-
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3982-
[encoder setBuffer: h_src1 offset:0 atIndex:3];
3983-
[encoder setBuffer: h_tpe offset:0 atIndex:4];
3984-
[encoder setBuffer: h_ids offset:0 atIndex:5];
3970+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
3971+
[encoder setBuffer: h_tpe offset:0 atIndex:2];
3972+
[encoder setBuffer: h_ids offset:0 atIndex:3];
3973+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
39853974

3986-
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3975+
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
39873976
}
39883977

39893978
{
@@ -4022,56 +4011,30 @@ static int ggml_metal_encode_node(
40224011
/*.nb01 =*/ nb01,
40234012
/*.nb02 =*/ nb02,
40244013
/*.nb03 =*/ nb03,
4025-
/*.neh12 =*/ neh12,
4026-
/*.nbh10 =*/ nbh10,
4027-
/*.nbh11 =*/ nbh11,
4028-
/*.nbh12 =*/ nbh12,
4029-
/*.nbh13 =*/ nbh13,
4030-
/*.neh0 =*/ neh0,
4031-
/*.neh1 =*/ neh1,
4014+
/*.ne11 =*/ ne11, // n_expert_used (bcast)
4015+
/*.nb10 =*/ nb10,
4016+
/*.nb11 =*/ nb11,
4017+
/*.nb12 =*/ nb12,
4018+
/*.nb13 =*/ nb13,
4019+
/*.ne20 =*/ ne20, // n_expert_used
4020+
/*.ne21 =*/ ne21, // n_tokens
4021+
/*.ne0 =*/ ne0,
4022+
/*.ne1 =*/ ne1,
40324023
/*.r2 =*/ r2,
40334024
/*.r3 =*/ r3,
40344025
};
40354026

40364027
[encoder setComputePipelineState:pipeline];
40374028
[encoder setBytes:&args length:sizeof(args) atIndex:0];
40384029
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4039-
[encoder setBuffer: h_src1 offset:0 atIndex:2];
4030+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
40404031
[encoder setBuffer: h_tpe offset:0 atIndex:3];
4041-
[encoder setBuffer: h_dst offset:0 atIndex:4];
4032+
[encoder setBuffer: h_ids offset:0 atIndex:4];
4033+
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
40424034

40434035
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
40444036
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
40454037
}
4046-
4047-
{
4048-
GGML_ASSERT(ne0 % 4 == 0);
4049-
4050-
const int nth = MIN(1024, ne0/4);
4051-
4052-
ggml_metal_kargs_mul_mm_id_map1 args = {
4053-
ne20, // n_expert_used
4054-
neh0,
4055-
neh1,
4056-
nbh1,
4057-
nbh2,
4058-
ne0,
4059-
nb1,
4060-
nb2,
4061-
};
4062-
4063-
id<MTLComputePipelineState> pipeline = nil;
4064-
4065-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
4066-
4067-
[encoder setComputePipelineState:pipeline];
4068-
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4069-
[encoder setBuffer: h_dst offset:0 atIndex:1];
4070-
[encoder setBuffer: h_ids offset:0 atIndex:2];
4071-
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
4072-
4073-
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4074-
}
40754038
} else {
40764039
id<MTLComputePipelineState> pipeline = nil;
40774040

0 commit comments

Comments
 (0)