Skip to content

Commit 6b88434

Browse files
committed
metal : mul_mm_id remove hdst
1 parent b730706 commit 6b88434

File tree

3 files changed: 48 additions and 123 deletions.

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -347,13 +347,15 @@ typedef struct {
347347
uint64_t nb01;
348348
uint64_t nb02;
349349
uint64_t nb03;
350+
int32_t ne20;
351+
int32_t ne21;
350352
int32_t neh12;
351353
uint64_t nbh10;
352354
uint64_t nbh11;
353355
uint64_t nbh12;
354356
uint64_t nbh13;
355-
int32_t neh0;
356-
int32_t neh1;
357+
int32_t ne0;
358+
int32_t ne1;
357359
int16_t r2;
358360
int16_t r3;
359361
} ggml_metal_kargs_mul_mm_id;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 8 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
397397
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
398398
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
399399
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
400-
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
401400
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
402401
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
403402
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1413,7 +1412,6 @@ @implementation GGMLMetalClass
14131412
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14141413
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
14151414
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1416-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
14171415
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14181416
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14191417
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3894,22 +3892,6 @@ static int ggml_metal_encode_node(
38943892
return 0;
38953893
}
38963894

3897-
const int64_t neh0 = ne0;
3898-
const int64_t neh1 = ne21;
3899-
const int64_t neh2 = ne02;
3900-
3901-
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
3902-
const uint64_t nbh1 = nbh0*neh0;
3903-
const uint64_t nbh2 = nbh1*neh1;
3904-
//const uint64_t nbh3 = nbh2*neh2;
3905-
3906-
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
3907-
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3908-
if (!h_dst) {
3909-
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3910-
return 0;
3911-
}
3912-
39133895
// tokens per expert
39143896
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
39153897
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3919,8 +3901,8 @@ static int ggml_metal_encode_node(
39193901
}
39203902

39213903
// id map
3922-
// [n_expert_used, n_tokens]
3923-
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
3904+
// [n_tokens, n_expert]
3905+
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
39243906
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
39253907
if (!h_ids) {
39263908
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
@@ -3992,13 +3974,15 @@ static int ggml_metal_encode_node(
39923974
/*.nb01 =*/ nb01,
39933975
/*.nb02 =*/ nb02,
39943976
/*.nb03 =*/ nb03,
3977+
/*.ne20 =*/ ne20, // n_expert_used
3978+
/*.ne21 =*/ ne21, // n_tokens
39953979
/*.neh12 =*/ neh12,
39963980
/*.nbh10 =*/ nbh10,
39973981
/*.nbh11 =*/ nbh11,
39983982
/*.nbh12 =*/ nbh12,
39993983
/*.nbh13 =*/ nbh13,
4000-
/*.neh0 =*/ neh0,
4001-
/*.neh1 =*/ neh1,
3984+
/*.ne0 =*/ ne0,
3985+
/*.ne1 =*/ ne1,
40023986
/*.r2 =*/ r2,
40033987
/*.r3 =*/ r3,
40043988
};
@@ -4008,40 +3992,12 @@ static int ggml_metal_encode_node(
40083992
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
40093993
[encoder setBuffer: h_src1 offset:0 atIndex:2];
40103994
[encoder setBuffer: h_tpe offset:0 atIndex:3];
4011-
[encoder setBuffer: h_dst offset:0 atIndex:4];
3995+
[encoder setBuffer: h_ids offset:0 atIndex:4];
3996+
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
40123997

40133998
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
40143999
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
40154000
}
4016-
4017-
{
4018-
GGML_ASSERT(ne0 % 4 == 0);
4019-
4020-
const int nth = MIN(1024, ne0/4);
4021-
4022-
ggml_metal_kargs_mul_mm_id_map1 args = {
4023-
ne20, // n_expert_used
4024-
neh0,
4025-
neh1,
4026-
nbh1,
4027-
nbh2,
4028-
ne0,
4029-
nb1,
4030-
nb2,
4031-
};
4032-
4033-
id<MTLComputePipelineState> pipeline = nil;
4034-
4035-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
4036-
4037-
[encoder setComputePipelineState:pipeline];
4038-
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4039-
[encoder setBuffer: h_dst offset:0 atIndex:1];
4040-
[encoder setBuffer: h_ids offset:0 atIndex:2];
4041-
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
4042-
4043-
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4044-
}
40454001
} else {
40464002
id<MTLComputePipelineState> pipeline = nil;
40474003

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 36 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -7500,14 +7500,15 @@ kernel void kernel_mul_mm_id_map0(
75007500
}
75017501

75027502
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
7503-
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
7503+
device T4 * hsrc1_tx4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
75047504

75057505
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
7506-
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
7506+
hsrc1_tx4[i00] = (T4) (src1_f32x4[i00]);
75077507
}
75087508

75097509
if (tpitg.x == 0) {
7510-
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
7510+
//ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
7511+
ids_i32[ide*args.neh11 + n_all] = i21*args.ne20 + i20;
75117512
}
75127513

75137514
++n_all;
@@ -7524,43 +7525,13 @@ typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
75247525

75257526
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
75267527

7527-
template<typename T>
7528-
kernel void kernel_mul_mm_id_map1(
7529-
constant ggml_metal_kargs_mul_mm_id_map1 & args,
7530-
device const char * hdst,
7531-
device const char * hids,
7532-
device char * dst,
7533-
uint3 tgpig[[threadgroup_position_in_grid]],
7534-
ushort3 tpitg[[thread_position_in_threadgroup]],
7535-
ushort3 ntg[[threads_per_threadgroup]]) {
7536-
const int i20 = tgpig[0]; // used expert
7537-
const int i21 = tgpig[1]; // token
7538-
7539-
device const int32_t * ids_i32 = (device const int32_t *) (hids);
7540-
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
7541-
7542-
const int id = ids_i32[i21*args.ne20 + i20];
7543-
7544-
const int ide = id / args.neh1;
7545-
const int idt = id % args.neh1;
7546-
7547-
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
7548-
7549-
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
7550-
dst_f32x4[i0] = hdst_f32x4[i0];
7551-
}
7552-
}
7553-
7554-
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
7555-
7556-
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
7557-
75587528
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
75597529
kernel void kernel_mul_mm_id(
75607530
constant ggml_metal_kargs_mul_mm_id & args,
75617531
device const char * src0,
75627532
device const char * src1,
7563-
device const char * tpe,
7533+
device const char * htpe,
7534+
device const char * hids,
75647535
device char * dst,
75657536
threadgroup char * shmem [[threadgroup(0)]],
75667537
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -7572,9 +7543,9 @@ kernel void kernel_mul_mm_id(
75727543

75737544
const int r0 = tgpig.y;
75747545
const int r1 = tgpig.x;
7575-
const int im = tgpig.z;
7546+
const int im = tgpig.z; // expert
75767547

7577-
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
7548+
device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
75787549

75797550
const int neh1 = tpe_i32[im];
75807551

@@ -7583,8 +7554,8 @@ kernel void kernel_mul_mm_id(
75837554
}
75847555

75857556
// if this block is of 64x32 shape or smaller
7586-
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
7587-
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
7557+
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
7558+
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
75887559

75897560
// a thread shouldn't load data outside of the matrix
75907561
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -7665,42 +7636,38 @@ kernel void kernel_mul_mm_id(
76657636
}
76667637
}
76677638

7668-
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
7669-
device float * C = (device float *) dst +
7670-
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
7671-
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
7639+
threadgroup_barrier(mem_flags::mem_threadgroup);
7640+
threadgroup float * temp_str = ((threadgroup float *) shmem) \
7641+
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
7642+
for (short i = 0; i < 8; i++) {
7643+
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
7644+
}
76727645

7673-
for (short i = 0; i < 8; i++) {
7674-
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
7675-
}
7676-
} else {
7677-
// block is smaller than 64x32, we should avoid writing data outside of the matrix
7678-
threadgroup_barrier(mem_flags::mem_threadgroup);
7679-
threadgroup float * temp_str = ((threadgroup float *) shmem) \
7680-
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
7681-
for (short i = 0; i < 8; i++) {
7682-
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
7683-
}
7646+
threadgroup_barrier(mem_flags::mem_threadgroup);
76847647

7685-
threadgroup_barrier(mem_flags::mem_threadgroup);
7648+
if (sgitg == 0) {
7649+
device const int32_t * ids_i32 = (device const int32_t *) (hids);
76867650

7687-
if (sgitg == 0) {
7688-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7689-
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
7690-
device float4 * D4 = (device float4 *) D;
7651+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7652+
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76917653

7692-
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7693-
threadgroup float4 * C4 = (threadgroup float4 *) C;
7654+
const int idt = id / args.ne20;
7655+
const int ide = id % args.ne20;
76947656

7695-
int i = 0;
7696-
for (; i < n_rows/4; i++) {
7697-
*(D4 + i) = *(C4 + i);
7698-
}
7657+
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
7658+
device float4 * D4 = (device float4 *) D;
76997659

7700-
i *= 4;
7701-
for (; i < n_rows; i++) {
7702-
*(D + i) = *(C + i);
7703-
}
7660+
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7661+
threadgroup float4 * C4 = (threadgroup float4 *) C;
7662+
7663+
int i = 0;
7664+
for (; i < n_rows/4; i++) {
7665+
*(D4 + i) = *(C4 + i);
7666+
}
7667+
7668+
i *= 4;
7669+
for (; i < n_rows; i++) {
7670+
*(D + i) = *(C + i);
77047671
}
77057672
}
77067673
}

0 commit comments

Comments
 (0)