Skip to content

Commit dfc72fc

Browse files
committed
metal : optimize mul_mm_id_map0
ggml-ci
1 parent 849d944 commit dfc72fc

File tree

2 files changed

+51
-19
lines changed

2 files changed

+51
-19
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
396396
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
397397
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
398398
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
399-
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
399+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
400+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
401+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
402+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
403+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
404+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
400405
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
401406
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
402407
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1411,7 +1416,12 @@ @implementation GGMLMetalClass
14111416
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14121417
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14131418
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1414-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1419+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1420+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1421+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1422+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1423+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1424+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14151425
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14161426
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14171427
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3907,7 +3917,17 @@ static int ggml_metal_encode_node(
39073917

39083918
id<MTLComputePipelineState> pipeline = nil;
39093919

3910-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
3920+
pipeline = nil;
3921+
3922+
switch (ne20) {
3923+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
3924+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
3925+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
3926+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
3927+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
3928+
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
3929+
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
3930+
}
39113931

39123932
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
39133933

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7488,7 +7488,7 @@ kernel void kernel_mul_mm(
74887488
}
74897489
}
74907490

7491-
template<typename T4>
7491+
template<short ne20> // n_expert_used
74927492
kernel void kernel_mul_mm_id_map0(
74937493
constant ggml_metal_kargs_mul_mm_id_map0 & args,
74947494
device const char * src2,
@@ -7501,31 +7501,38 @@ kernel void kernel_mul_mm_id_map0(
75017501

75027502
uint32_t n_all = 0;
75037503

7504-
device int32_t * ids_i32 = (device int32_t *) (hids);
7504+
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
75057505

75067506
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
7507-
{
7507+
if (i21 + tpitg < args.ne21) {
75087508
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
75097509

7510-
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*args.ne20;
7510+
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
75117511

7512-
for (int i20 = 0; i20 < args.ne20 && i21 + tpitg < args.ne21; i20++) {
7512+
#pragma unroll(ne20)
7513+
for (short i20 = 0; i20 < ne20; i20++) {
75137514
sids[i20] = src2_i32[i20];
75147515
}
75157516
}
75167517

75177518
threadgroup_barrier(mem_flags::mem_threadgroup);
75187519

7519-
for (int t = 0; t < ntg && i21 + t < args.ne21; t++) {
7520-
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + t*args.ne20;
7520+
for (short t = 0; t < ntg; t++) {
7521+
if (i21 + t >= args.ne21) {
7522+
break;
7523+
}
75217524

7522-
for (int i20 = 0; i20 < args.ne20; i20++) {
7523-
if (sids[i20] == ide) {
7524-
ids_i32[ide*args.ne21 + n_all] = (i21 + t)*args.ne20 + i20;
7525-
++n_all;
7526-
break;
7527-
}
7525+
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
7526+
7527+
short sel = 0;
7528+
#pragma unroll(ne20)
7529+
for (short i20 = 0; i20 < ne20; i20++) {
7530+
sel += (sids[i20] == ide)*(i20 + 1);
75287531
}
7532+
7533+
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
7534+
7535+
n_all += sel > 0;
75297536
}
75307537

75317538
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -7535,9 +7542,14 @@ kernel void kernel_mul_mm_id_map0(
75357542
tpe_u32[ide] = n_all;
75367543
}
75377544

7538-
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
7545+
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
75397546

7540-
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
7547+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
7548+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
7549+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
7550+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
7551+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
7552+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
75417553

75427554
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
75437555
kernel void kernel_mul_mm_id(
@@ -7563,7 +7575,7 @@ kernel void kernel_mul_mm_id(
75637575
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
75647576
device const int32_t * ids_i32 = (device const int32_t *) (hids);
75657577

7566-
const uint32_t neh1 = tpe_u32[im];
7578+
const int32_t neh1 = tpe_u32[im];
75677579

75687580
if (r1*BLOCK_SIZE_N >= neh1) {
75697581
return;

0 commit comments

Comments
 (0)