@@ -6317,18 +6317,19 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63176317    const  uint im = tgpig.z ;
63186318
63196319    //  if this block is of 64x32 shape or smaller
6320-     short  n_rows = (ne0 - r0 *  BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 *  BLOCK_SIZE_M) : BLOCK_SIZE_M;
6321-     short  n_cols = (ne1 - r1 *  BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 *  BLOCK_SIZE_N) : BLOCK_SIZE_N;
6320+     short  n_rows = (ne0 - r0* BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0* BLOCK_SIZE_M) : BLOCK_SIZE_M;
6321+     short  n_cols = (ne1 - r1* BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1* BLOCK_SIZE_N) : BLOCK_SIZE_N;
63226322
63236323    //  a thread shouldn't load data outside of the matrix
63246324    short  thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
63256325    short  thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
63266326
63276327    simdgroup_T8x8     ma[4 ];
63286328    simdgroup_float8x8 mb[2 ];
6329-     simdgroup_float8x8 c_res[8 ];
6330-     for  (int  i = 0 ; i < 8 ; i++){
6331-         c_res[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6329+     simdgroup_float8x8 mc[8 ];
6330+ 
6331+     for  (short  i = 0 ; i < 8 ; i++){
6332+         mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
63326333    }
63336334
63346335    short  il = (tiitg % THREAD_PER_ROW);
@@ -6339,7 +6340,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63396340    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
63406341    ushort offset1 = il/nl;
63416342
6342-     device const  block_q * x = (device const  block_q *)(src0 + (r0 *  BLOCK_SIZE_M + thread_row) *  nb01 + offset0) + offset1;
6343+     device const  block_q * x = (device const  block_q *)(src0 + (r0* BLOCK_SIZE_M + thread_row)* nb01 + offset0) + offset1;
63436344    device const  float    * y = (device const  float    *)(src1
63446345        + nb13 * i13
63456346        + nb12 * i12
@@ -6353,13 +6354,13 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63536354        threadgroup_barrier (mem_flags::mem_threadgroup);
63546355
63556356        #pragma  unroll(16)
6356-         for  (int  i = 0 ; i < 16 ; i++) {
6357-             *(sa + SG_MAT_SIZE * ((tiitg /  THREAD_PER_ROW /  8 ) \
6358-             +                     (tiitg %  THREAD_PER_ROW) *  16  + (i /  8 ) *  8 ) \
6359-             +                     (tiitg /  THREAD_PER_ROW) %  8   + (i &  7 ) *  8 ) = temp_a[i/4 ][i%4 ];
6357+         for  (short  i = 0 ; i < 16 ; i++) {
6358+             *(sa + SG_MAT_SIZE * ((tiitg/ THREAD_PER_ROW/ 8 ) \
6359+             +                     (tiitg% THREAD_PER_ROW)* 16  + (i/ 8 )* 8 ) \
6360+             +                     (tiitg/ THREAD_PER_ROW)% 8   + (i& 7 )* 8 ) = temp_a[i/4 ][i%4 ];
63606361        }
63616362
6362-         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) *  8  *  32  + 8  *  (tiitg /  THREAD_PER_COL)) = *((device float2x4 *)y);
6363+         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)* 8 * 32  + 8 * (tiitg/ THREAD_PER_COL)) = *((device float2x4 *)  y);
63636364
63646365        il = (il + 2  < nl) ? il + 2  : il % 2 ;
63656366        x  = (il < 2 ) ? x + (2 +nl-1 )/nl : x;
@@ -6368,44 +6369,44 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63686369        threadgroup_barrier (mem_flags::mem_threadgroup);
63696370
63706371        //  load matrices from threadgroup memory and conduct outer products
6371-         threadgroup T     * lsma = (sa + THREAD_MAT_M *  SG_MAT_SIZE *  (sgitg %  2 ));
6372-         threadgroup float  * lsmb = (sb + THREAD_MAT_N *  SG_MAT_SIZE *  (sgitg /  2 ));
6372+         threadgroup T     * lsma = (sa + THREAD_MAT_M* SG_MAT_SIZE* (sgitg% 2 ));
6373+         threadgroup float  * lsmb = (sb + THREAD_MAT_N* SG_MAT_SIZE* (sgitg/ 2 ));
63736374
63746375        #pragma  unroll(4)
6375-         for  (int  ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6376+         for  (short  ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
63766377            #pragma  unroll(4)
6377-             for  (int  i = 0 ; i < 4 ; i++) {
6378-                 simdgroup_load (ma[i],lsma + SG_MAT_SIZE * i);
6378+             for  (short  i = 0 ; i < 4 ; i++) {
6379+                 simdgroup_load (ma[i],  lsma + SG_MAT_SIZE * i);
63796380            }
63806381            simdgroup_barrier (mem_flags::mem_none);
63816382            #pragma  unroll(2)
6382-             for  (int  i = 0 ; i < 2 ; i++) {
6383-                 simdgroup_load (mb[i],lsmb + SG_MAT_SIZE * i);
6383+             for  (short  i = 0 ; i < 2 ; i++) {
6384+                 simdgroup_load (mb[i],  lsmb + SG_MAT_SIZE * i);
63846385            }
63856386
6386-             lsma += BLOCK_SIZE_M /  SG_MAT_ROW * SG_MAT_SIZE;
6387-             lsmb += BLOCK_SIZE_N /  SG_MAT_ROW * SG_MAT_SIZE;
6387+             lsma += BLOCK_SIZE_M/ SG_MAT_ROW * SG_MAT_SIZE;
6388+             lsmb += BLOCK_SIZE_N/ SG_MAT_ROW * SG_MAT_SIZE;
63886389
63896390            #pragma  unroll(8)
6390-             for  (int  i = 0 ; i < 8 ; i++){
6391-                 simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
6391+             for  (short  i = 0 ; i < 8 ; i++){
6392+                 simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
63926393            }
63936394        }
63946395    }
63956396
63966397    if  ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
63976398        device float  * C = dst + (BLOCK_SIZE_M * r0 + 32  * (sgitg &  1 )) \
63986399                               + (BLOCK_SIZE_N * r1 + 16  * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6399-         for  (int  i = 0 ; i < 8 ; i++) {
6400-             simdgroup_store (c_res [i], C + 8  * (i%4 ) + 8  * ne0 * (i/4 ), ne0);
6400+         for  (short  i = 0 ; i < 8 ; i++) {
6401+             simdgroup_store (mc [i], C + 8  * (i%4 ) + 8  * ne0 * (i/4 ), ne0);
64016402        }
64026403    } else  {
64036404        //  block is smaller than 64x32, we should avoid writing data outside of the matrix
64046405        threadgroup_barrier (mem_flags::mem_threadgroup);
6405-         threadgroup float  * temp_str = ((threadgroup float  *)shared_memory) \
6406-                                       + 32  * (sgitg&1 ) + (16  * (sgitg>>1 )) *  BLOCK_SIZE_M;
6407-         for  (int  i = 0 ; i < 8 ; i++) {
6408-             simdgroup_store (c_res [i], temp_str + 8  *  (i%4 ) + 8  *  BLOCK_SIZE_M *  (i/4 ), BLOCK_SIZE_M);
6406+         threadgroup float  * temp_str = ((threadgroup float  *)  shared_memory) \
6407+                                       + 32  * (sgitg&1 ) + (16  * (sgitg>>1 ))* BLOCK_SIZE_M;
6408+         for  (short  i = 0 ; i < 8 ; i++) {
6409+             simdgroup_store (mc [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M* (i/4 ), BLOCK_SIZE_M);
64096410        }
64106411
64116412        threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments