@@ -6318,18 +6318,19 @@ kernel void kernel_mul_mm(device const uchar * src0,
63186318 const uint im = tgpig.z ;
63196319
63206320 // if this block is of 64x32 shape or smaller
6321- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
6322- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
6321+ short n_rows = (ne0 - r0* BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0* BLOCK_SIZE_M) : BLOCK_SIZE_M;
6322+ short n_cols = (ne1 - r1* BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1* BLOCK_SIZE_N) : BLOCK_SIZE_N;
63236323
63246324 // a thread shouldn't load data outside of the matrix
63256325 short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
63266326 short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
63276327
63286328 simdgroup_T8x8 ma[4 ];
63296329 simdgroup_float8x8 mb[2 ];
6330- simdgroup_float8x8 c_res[8 ];
6331- for (int i = 0 ; i < 8 ; i++){
6332- c_res[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6330+ simdgroup_float8x8 mc[8 ];
6331+
6332+ for (short i = 0 ; i < 8 ; i++){
6333+ mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
63336334 }
63346335
63356336 short il = (tiitg % THREAD_PER_ROW);
@@ -6340,7 +6341,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
63406341 uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
63416342 ushort offset1 = il/nl;
63426343
6343- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
6344+ device const block_q * x = (device const block_q *)(src0 + (r0* BLOCK_SIZE_M + thread_row)* nb01 + offset0) + offset1;
63446345 device const float * y = (device const float *)(src1
63456346 + nb13 * i13
63466347 + nb12 * i12
@@ -6354,13 +6355,13 @@ kernel void kernel_mul_mm(device const uchar * src0,
63546355 threadgroup_barrier (mem_flags::mem_threadgroup);
63556356
63566357 #pragma unroll(16)
6357- for (int i = 0 ; i < 16 ; i++) {
6358- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8 ) \
6359- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8 ) * 8 ) \
6360- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
6358+ for (short i = 0 ; i < 16 ; i++) {
6359+ *(sa + SG_MAT_SIZE * ((tiitg/ THREAD_PER_ROW/ 8 ) \
6360+ + (tiitg% THREAD_PER_ROW)* 16 + (i/ 8 )* 8 ) \
6361+ + (tiitg/ THREAD_PER_ROW)% 8 + (i& 7 )* 8 ) = temp_a[i/4 ][i%4 ];
63616362 }
63626363
6363- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6364+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)* 8 * 32 + 8 * (tiitg/ THREAD_PER_COL)) = *((device float2x4 *) y);
63646365
63656366 il = (il + 2 < nl) ? il + 2 : il % 2 ;
63666367 x = (il < 2 ) ? x + (2 +nl-1 )/nl : x;
@@ -6369,53 +6370,64 @@ kernel void kernel_mul_mm(device const uchar * src0,
63696370 threadgroup_barrier (mem_flags::mem_threadgroup);
63706371
63716372 // load matrices from threadgroup memory and conduct outer products
6372- threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
6373- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
6373+ threadgroup T * lsma = (sa + THREAD_MAT_M* SG_MAT_SIZE* (sgitg% 2 ));
6374+ threadgroup float * lsmb = (sb + THREAD_MAT_N* SG_MAT_SIZE* (sgitg/ 2 ));
63746375
63756376 #pragma unroll(4)
6376- for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6377+ for (short ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
63776378 #pragma unroll(4)
6378- for (int i = 0 ; i < 4 ; i++) {
6379- simdgroup_load (ma[i],lsma + SG_MAT_SIZE * i);
6379+ for (short i = 0 ; i < 4 ; i++) {
6380+ simdgroup_load (ma[i], lsma + SG_MAT_SIZE * i);
63806381 }
63816382 simdgroup_barrier (mem_flags::mem_none);
63826383 #pragma unroll(2)
6383- for (int i = 0 ; i < 2 ; i++) {
6384- simdgroup_load (mb[i],lsmb + SG_MAT_SIZE * i);
6384+ for (short i = 0 ; i < 2 ; i++) {
6385+ simdgroup_load (mb[i], lsmb + SG_MAT_SIZE * i);
63856386 }
63866387
6387- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
6388- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
6388+ lsma += BLOCK_SIZE_M/ SG_MAT_ROW * SG_MAT_SIZE;
6389+ lsmb += BLOCK_SIZE_N/ SG_MAT_ROW * SG_MAT_SIZE;
63896390
63906391 #pragma unroll(8)
6391- for (int i = 0 ; i < 8 ; i++){
6392- simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
6392+ for (short i = 0 ; i < 8 ; i++){
6393+ simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
63936394 }
63946395 }
63956396 }
63966397
63976398 if ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
63986399 device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1 )) \
63996400 + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6400- for (int i = 0 ; i < 8 ; i++) {
6401- simdgroup_store (c_res [i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6401+ for (short i = 0 ; i < 8 ; i++) {
6402+ simdgroup_store (mc [i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
64026403 }
64036404 } else {
64046405 // block is smaller than 64x32, we should avoid writing data outside of the matrix
64056406 threadgroup_barrier (mem_flags::mem_threadgroup);
6406- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
6407- + 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
6408- for (int i = 0 ; i < 8 ; i++) {
6409- simdgroup_store (c_res [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6407+ threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
6408+ + 32 * (sgitg&1 ) + (16 * (sgitg>>1 ))* BLOCK_SIZE_M;
6409+ for (short i = 0 ; i < 8 ; i++) {
6410+ simdgroup_store (mc [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M* (i/4 ), BLOCK_SIZE_M);
64106411 }
64116412
64126413 threadgroup_barrier (mem_flags::mem_threadgroup);
64136414
6414- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
64156415 if (sgitg == 0 ) {
6416- for (int i = 0 ; i < n_rows; i++) {
6417- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6418- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
6416+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6417+ device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
6418+ device float4 * D4 = (device float4 *) D;
6419+
6420+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6421+ threadgroup float4 * C4 = (threadgroup float4 *) C;
6422+
6423+ int i = 0 ;
6424+ for (; i < n_rows/4 ; i++) {
6425+ *(D4 + i) = *(C4 + i);
6426+ }
6427+
6428+ i *= 4 ;
6429+ for (; i < n_rows; i++) {
6430+ *(D + i) = *(C + i);
64196431 }
64206432 }
64216433 }
0 commit comments