@@ -7480,31 +7480,45 @@ kernel void kernel_mul_mm_id_map0(
74807480 device const char * src2,
74817481 device char * htpe,
74827482 device char * hids,
7483- uint3 tgpig[[threadgroup_position_in_grid ]],
7484- ushort3 tpitg[[thread_position_in_threadgroup]],
7485- ushort3 ntg[[threads_per_threadgroup]]) {
7486- const int ide = tgpig[ 0 ] ; // expert id
7483+ threadgroup char * shmem [[threadgroup( 0 ) ]],
7484+ ushort tpitg[[thread_position_in_threadgroup]],
7485+ ushort ntg[[threads_per_threadgroup]]) {
7486+ const short ide = tpitg ; // expert id
74877487
7488- int n_all = 0 ;
7488+ uint32_t n_all = 0 ;
74897489
74907490 device int32_t * ids_i32 = (device int32_t *) (hids);
74917491
7492- for (int i21 = 0 ; i21 < args.ne21 ; i21++) { // n_tokens
7493- device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21 );
7492+ for (int i21 = 0 ; i21 < args.ne21 ; i21 += ntg) { // n_tokens
7493+ {
7494+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21 );
74947495
7495- for (int i20 = 0 ; i20 < args.ne20 ; i20++) { // n_expert_used
7496- if (src2_i32[i20] != ide) {
7497- continue ;
7496+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*args.ne20 ;
7497+
7498+ for (int i20 = 0 ; i20 < args.ne20 && i21 + tpitg < args.ne21 ; i20++) {
7499+ sids[i20] = src2_i32[i20];
74987500 }
7501+ }
7502+
7503+ threadgroup_barrier (mem_flags::mem_threadgroup);
74997504
7500- ids_i32[ide*args.ne21 + n_all] = i21*args.ne20 + i20;
7505+ for (int t = 0 ; t < ntg && i21 + t < args.ne21 ; t++) {
7506+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + t*args.ne20 ;
75017507
7502- ++n_all;
7508+ for (int i20 = 0 ; i20 < args.ne20 ; i20++) {
7509+ if (sids[i20] == ide) {
7510+ ids_i32[ide*args.ne21 + n_all] = (i21 + t)*args.ne20 + i20;
7511+ ++n_all;
7512+ break ;
7513+ }
7514+ }
75037515 }
7516+
7517+ threadgroup_barrier (mem_flags::mem_threadgroup);
75047518 }
75057519
7506- device int32_t * tpe_i32 = (device int32_t *) (htpe);
7507- tpe_i32 [ide] = n_all;
7520+ device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
7521+ tpe_u32 [ide] = n_all;
75087522}
75097523
75107524typedef decltype (kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
@@ -7532,10 +7546,10 @@ kernel void kernel_mul_mm_id(
75327546 const int r1 = tgpig.x ;
75337547 const int im = tgpig.z ; // expert
75347548
7535- device const int32_t * tpe_i32 = (device const int32_t *) (htpe);
7536- device const int32_t * ids_i32 = (device const int32_t *) (hids);
7549+ device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
7550+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
75377551
7538- const int neh1 = tpe_i32 [im];
7552+ const uint32_t neh1 = tpe_u32 [im];
75397553
75407554 if (r1*BLOCK_SIZE_N >= neh1) {
75417555 return ;
@@ -7561,9 +7575,9 @@ kernel void kernel_mul_mm_id(
75617575
75627576 const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
75637577
7564- const int i11 = (id % args.ne20 ) % args.ne11 ;
7565- const int i12 = (id / args.ne20 );
7566- const int i13 = 0 ;
7578+ const short i11 = (id % args.ne20 ) % args.ne11 ;
7579+ const short i12 = (id / args.ne20 );
7580+ const short i13 = 0 ;
75677581
75687582 const uint64_t offset0 = im*args.nb02 + i13*args.nb03 ;
75697583 const short offset1 = il/nl;
@@ -7632,17 +7646,18 @@ kernel void kernel_mul_mm_id(
76327646 threadgroup float * temp_str = ((threadgroup float *) shmem) \
76337647 + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
76347648
7649+ #pragma unroll(8)
76357650 for (short i = 0 ; i < 8 ; i++) {
76367651 simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
76377652 }
76387653
76397654 threadgroup_barrier (mem_flags::mem_threadgroup);
76407655
7641- for (int j = sgitg; j < n_cols; j += 4 ) {
7656+ for (short j = sgitg; j < n_cols; j += 4 ) {
76427657 const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76437658
7644- const int ide = id % args.ne20 ;
7645- const int idt = id / args.ne20 ;
7659+ const short ide = id % args.ne20 ;
7660+ const short idt = id / args.ne20 ;
76467661
76477662 device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
76487663 device float4 * D4 = (device float4 *) D;
0 commit comments