@@ -7477,9 +7477,7 @@ kernel void kernel_mul_mm(
74777477template <typename T4>
74787478kernel void kernel_mul_mm_id_map0 (
74797479 constant ggml_metal_kargs_mul_mm_id_map0 & args,
7480- device const char * src1,
74817480 device const char * src2,
7482- device char * hsrc1,
74837481 device char * htpe,
74847482 device char * hids,
74857483 uint3 tgpig[[threadgroup_position_in_grid]],
@@ -7491,24 +7489,16 @@ kernel void kernel_mul_mm_id_map0(
74917489
74927490 device int32_t * ids_i32 = (device int32_t *) (hids);
74937491
7494- for (int i21 = 0 ; i21 < args.neh11 ; i21++) { // n_tokens
7492+ for (int i21 = 0 ; i21 < args.ne21 ; i21++) { // n_tokens
74957493 device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21 );
74967494
74977495 for (int i20 = 0 ; i20 < args.ne20 ; i20++) { // n_expert_used
74987496 if (src2_i32[i20] != ide) {
74997497 continue ;
75007498 }
75017499
7502- device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11 )*args.nb11 );
7503- device T4 * hsrc1_tx4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11 );
7504-
7505- for (int64_t i00 = tpitg.x ; i00 < args.ne10 /4 ; i00 += ntg.x ) {
7506- hsrc1_tx4[i00] = (T4) (src1_f32x4[i00]);
7507- }
7508-
75097500 if (tpitg.x == 0 ) {
7510- // ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
7511- ids_i32[ide*args.neh11 + n_all] = i21*args.ne20 + i20;
7501+ ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
75127502 }
75137503
75147504 ++n_all;
@@ -7546,6 +7536,7 @@ kernel void kernel_mul_mm_id(
75467536 const int im = tgpig.z ; // expert
75477537
75487538 device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
7539+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
75497540
75507541 const int neh1 = tpe_i32[im];
75517542
@@ -7571,20 +7562,23 @@ kernel void kernel_mul_mm_id(
75717562
75727563 short il = (tiitg % THREAD_PER_ROW);
75737564
7574- const int i12 = im%args.neh12 ;
7575- const int i13 = im/args.neh12 ;
7565+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
75767566
7577- const uint64_t offset0 = (i12/args.r2 )*args.nb02 + (i13/args.r3 )*args.nb03 ;
7567+ const int i11 = (id % args.ne20 ) % args.ne11 ;
7568+ const int i12 = (id / args.ne20 );
7569+ const int i13 = 0 ;
7570+
7571+ const uint64_t offset0 = im*args.nb02 + i13*args.nb03 ;
75787572 const short offset1 = il/nl;
75797573
75807574 device const block_q * x = (device const block_q *)(src0
75817575 + args.nb01 *(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
75827576
7583- device const half * y = (device const half *)(src1
7584- + args.nbh13 *i13
7585- + args.nbh12 *i12
7586- + args.nbh11 *(r1*BLOCK_SIZE_N + thread_col)
7587- + args.nbh10 *(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
7577+ device const float * y = (device const float *)(src1
7578+ + args.nb13 *i13
7579+ + args.nb12 *i12
7580+ + args.nb11 *i11
7581+ + args.nb10 *(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
75887582
75897583 for (int loop_k = 0 ; loop_k < args.ne00 ; loop_k += BLOCK_SIZE_K) {
75907584 // load data and store to threadgroup memory
@@ -7600,7 +7594,7 @@ kernel void kernel_mul_mm_id(
76007594 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
76017595 }
76027596
7603- *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
7597+ *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)( *((device float2x4 *) y) );
76047598
76057599 il = (il + 2 < nl) ? il + 2 : il % 2 ;
76067600 x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -7646,13 +7640,11 @@ kernel void kernel_mul_mm_id(
76467640 threadgroup_barrier (mem_flags::mem_threadgroup);
76477641
76487642 if (sgitg == 0 ) {
7649- device const int32_t * ids_i32 = (device const int32_t *) (hids);
7650-
76517643 for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
76527644 const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76537645
7654- const int idt = id / args.ne20 ;
76557646 const int ide = id % args.ne20 ;
7647+ const int idt = id / args.ne20 ;
76567648
76577649 device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
76587650 device float4 * D4 = (device float4 *) D;
0 commit comments