@@ -7500,14 +7500,15 @@ kernel void kernel_mul_mm_id_map0(
75007500 }
75017501
75027502 device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11 )*args.nb11 );
7503- device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11 );
7503+ device T4 * hsrc1_tx4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11 );
75047504
75057505 for (int64_t i00 = tpitg.x ; i00 < args.ne10 /4 ; i00 += ntg.x ) {
7506- hsrc1_f32x4 [i00] = (T4) (src1_f32x4[i00]);
7506+ hsrc1_tx4 [i00] = (T4) (src1_f32x4[i00]);
75077507 }
75087508
75097509 if (tpitg.x == 0 ) {
7510- ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
7510+ // ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
7511+ ids_i32[ide*args.neh11 + n_all] = i21*args.ne20 + i20;
75117512 }
75127513
75137514 ++n_all;
@@ -7524,43 +7525,13 @@ typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
75247525
75257526template [[host_name(" kernel_mul_mm_id_map0_f16" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
75267527
7527- template <typename T>
7528- kernel void kernel_mul_mm_id_map1 (
7529- constant ggml_metal_kargs_mul_mm_id_map1 & args,
7530- device const char * hdst,
7531- device const char * hids,
7532- device char * dst,
7533- uint3 tgpig[[threadgroup_position_in_grid]],
7534- ushort3 tpitg[[thread_position_in_threadgroup]],
7535- ushort3 ntg[[threads_per_threadgroup]]) {
7536- const int i20 = tgpig[0 ]; // used expert
7537- const int i21 = tgpig[1 ]; // token
7538-
7539- device const int32_t * ids_i32 = (device const int32_t *) (hids);
7540- device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2 );
7541-
7542- const int id = ids_i32[i21*args.ne20 + i20];
7543-
7544- const int ide = id / args.neh1 ;
7545- const int idt = id % args.neh1 ;
7546-
7547- device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2 );
7548-
7549- for (int64_t i0 = tpitg.x ; i0 < args.neh0 /4 ; i0 += ntg.x ) {
7550- dst_f32x4[i0] = hdst_f32x4[i0];
7551- }
7552- }
7553-
7554- typedef decltype (kernel_mul_mm_id_map1<float >) kernel_mul_mm_id_map1_t;
7555-
7556- template [[host_name(" kernel_mul_mm_id_map1_f32" )]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float >;
7557-
75587528template <typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &)>
75597529kernel void kernel_mul_mm_id (
75607530 constant ggml_metal_kargs_mul_mm_id & args,
75617531 device const char * src0,
75627532 device const char * src1,
7563- device const char * tpe,
7533+ device const char * htpe,
7534+ device const char * hids,
75647535 device char * dst,
75657536 threadgroup char * shmem [[threadgroup(0 )]],
75667537 uint3 tgpig[[threadgroup_position_in_grid]],
@@ -7572,9 +7543,9 @@ kernel void kernel_mul_mm_id(
75727543
75737544 const int r0 = tgpig.y ;
75747545 const int r1 = tgpig.x ;
7575- const int im = tgpig.z ;
7546+ const int im = tgpig.z ; // expert
75767547
7577- device const int32_t * tpe_i32 = (device const int32_t *) (tpe );
7548+ device const int32_t * tpe_i32 = (device const int32_t *) (htpe );
75787549
75797550 const int neh1 = tpe_i32[im];
75807551
@@ -7583,8 +7554,8 @@ kernel void kernel_mul_mm_id(
75837554 }
75847555
75857556 // if this block is of 64x32 shape or smaller
7586- const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
7587- const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
7557+ const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
7558+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
75887559
75897560 // a thread shouldn't load data outside of the matrix
75907561 const short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
@@ -7665,42 +7636,38 @@ kernel void kernel_mul_mm_id(
76657636 }
76667637 }
76677638
7668- if ((r0 + 1 ) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1 ) * BLOCK_SIZE_N <= neh1) {
7669- device float * C = (device float *) dst +
7670- (BLOCK_SIZE_M * r0 + 32 *(sgitg & 1 )) + \
7671- (BLOCK_SIZE_N * r1 + 16 *(sgitg >> 1 )) * args.neh0 + im*args.neh1 *args.neh0 ;
7639+ threadgroup_barrier (mem_flags::mem_threadgroup);
7640+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
7641+ + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
7642+ for (short i = 0 ; i < 8 ; i++) {
7643+ simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
7644+ }
76727645
7673- for (short i = 0 ; i < 8 ; i++) {
7674- simdgroup_store (mc[i], C + 8 * (i%4 ) + 8 * args.neh0 * (i/4 ), args.neh0 );
7675- }
7676- } else {
7677- // block is smaller than 64x32, we should avoid writing data outside of the matrix
7678- threadgroup_barrier (mem_flags::mem_threadgroup);
7679- threadgroup float * temp_str = ((threadgroup float *) shmem) \
7680- + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
7681- for (short i = 0 ; i < 8 ; i++) {
7682- simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
7683- }
7646+ threadgroup_barrier (mem_flags::mem_threadgroup);
76847647
7685- threadgroup_barrier (mem_flags::mem_threadgroup);
7648+ if (sgitg == 0 ) {
7649+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
76867650
7687- if (sgitg == 0 ) {
7688- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7689- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1 *args.neh0 ;
7690- device float4 * D4 = (device float4 *) D;
7651+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7652+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76917653
7692- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M) ;
7693- threadgroup float4 * C4 = (threadgroup float4 *) C ;
7654+ const int idt = id / args. ne20 ;
7655+ const int ide = id % args. ne20 ;
76947656
7695- int i = 0 ;
7696- for (; i < n_rows/4 ; i++) {
7697- *(D4 + i) = *(C4 + i);
7698- }
7657+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
7658+ device float4 * D4 = (device float4 *) D;
76997659
7700- i *= 4 ;
7701- for (; i < n_rows; i++) {
7702- *(D + i) = *(C + i);
7703- }
7660+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7661+ threadgroup float4 * C4 = (threadgroup float4 *) C;
7662+
7663+ int i = 0 ;
7664+ for (; i < n_rows/4 ; i++) {
7665+ *(D4 + i) = *(C4 + i);
7666+ }
7667+
7668+ i *= 4 ;
7669+ for (; i < n_rows; i++) {
7670+ *(D + i) = *(C + i);
77047671 }
77057672 }
77067673 }
0 commit comments