@@ -6295,8 +6295,8 @@ kernel void kernel_mul_mm(device const  uchar * src0,
62956295                          uint                  tiitg[[thread_index_in_threadgroup]],
62966296                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
62976297
6298-     threadgroup T      * sa = (threadgroup T      *)(shared_memory);
6299-     threadgroup float  * sb = (threadgroup float  *)(shared_memory + 4096 );
6298+     threadgroup T    * sa = (threadgroup T    *)(shared_memory);
6299+     threadgroup half  * sb = (threadgroup half  *)(shared_memory + 4096 );
63006300
63016301    const  uint r0 = tgpig.y ;
63026302    const  uint r1 = tgpig.x ;
@@ -6310,11 +6310,11 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63106310    short  thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
63116311    short  thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
63126312
6313-     simdgroup_T8x8      ma[4 ];
6314-     simdgroup_float8x8  mb[2 ];
6315-     simdgroup_float8x8 c_res [8 ];
6313+     simdgroup_T8x8    ma[4 ];
6314+     simdgroup_half8x8  mb[2 ];
6315+     simdgroup_half8x8 mc [8 ];
63166316    for  (int  i = 0 ; i < 8 ; i++){
6317-         c_res [i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6317+         mc [i] = make_filled_simdgroup_matrix<half , 8 >(0 .h );
63186318    }
63196319
63206320    short  il = (tiitg % THREAD_PER_ROW);
@@ -6345,17 +6345,17 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63456345            +                     (tiitg / THREAD_PER_ROW) % 8   + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
63466346        }
63476347
6348-         *(threadgroup float2x4  *)(sb + (tiitg % THREAD_PER_COL) * 8  * 32  + 8  * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6348+         *(threadgroup half2x4  *)(sb + (tiitg % THREAD_PER_COL) * 8  * 32  + 8  * (tiitg / THREAD_PER_COL)) = (half2x4)( *((device float2x4 *)y) );
63496349
63506350        il = (il + 2  < nl) ? il + 2  : il % 2 ;
6351-         x  = (il < 2 ) ? x + (2 +nl- 1 )/nl : x;
6351+         x  = (il <   2 ) ? x + (2  + nl -  1 )/nl : x;
63526352        y += BLOCK_SIZE_K;
63536353
63546354        threadgroup_barrier (mem_flags::mem_threadgroup);
63556355
63566356        //  load matrices from threadgroup memory and conduct outer products
6357-         threadgroup T      * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
6358-         threadgroup float  * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
6357+         threadgroup T    * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
6358+         threadgroup half  * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
63596359
63606360        #pragma  unroll(4)
63616361        for  (int  ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
@@ -6374,7 +6374,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63746374
63756375            #pragma  unroll(8)
63766376            for  (int  i = 0 ; i < 8 ; i++){
6377-                 simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
6377+                 simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
63786378            }
63796379        }
63806380    }
@@ -6383,15 +6383,22 @@ kernel void kernel_mul_mm(device const  uchar * src0,
63836383        device float  * C = dst + (BLOCK_SIZE_M * r0 + 32  * (sgitg &  1 )) \
63846384                               + (BLOCK_SIZE_N * r1 + 16  * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
63856385        for  (int  i = 0 ; i < 8 ; i++) {
6386-             simdgroup_store (c_res[i], C + 8  * (i%4 ) + 8  * ne0 * (i/4 ), ne0);
6386+             //  cast to f32
6387+             simdgroup_float8x8 mc_f32 (1 .0f );
6388+             simdgroup_multiply (mc_f32, mc[i], mc_f32);
6389+             simdgroup_store (mc_f32, C + 8  * (i%4 ) + 8  * ne0 * (i/4 ), ne0);
6390+             // simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
63876391        }
63886392    } else  {
63896393        //  block is smaller than 64x32, we should avoid writing data outside of the matrix
63906394        threadgroup_barrier (mem_flags::mem_threadgroup);
63916395        threadgroup float  * temp_str = ((threadgroup float  *)shared_memory) \
6392-                                       + 32  * (sgitg&1 ) + (16  * (sgitg>>1 )) * BLOCK_SIZE_M;
6396+                                         + 32  * (sgitg&1 ) + (16  * (sgitg>>1 )) * BLOCK_SIZE_M;
63936397        for  (int  i = 0 ; i < 8 ; i++) {
6394-             simdgroup_store (c_res[i], temp_str + 8  * (i%4 ) + 8  * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6398+             simdgroup_float8x8 mc_f32 (1 .0f );
6399+             simdgroup_multiply (mc_f32, mc[i], mc_f32);
6400+             simdgroup_store (mc_f32, temp_str + 8  * (i%4 ) + 8  * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6401+             // simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
63956402        }
63966403
63976404        threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments