@@ -8145,17 +8145,24 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
-    const int r0 = tgpig.y;
-    const int r1 = tgpig.x;
+    constexpr int NR0 = 64;
+    constexpr int NR1 = 32;
+
+    constexpr int NK  = 32;
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
     const int im = tgpig.z;
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
 
     // if this block is of 64x32 shape or smaller
-    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
 
     // a thread shouldn't load data outside of the matrix
-    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
+    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
 
     S0_8x8 ma[4];
     S1_8x8 mb[2];
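
For orientation: the new constants replace the old BLOCK_SIZE_*/THREAD_PER_* macros. Each threadgroup now produces a 64x32 (NR0 x NR1) output tile and steps over the K dimension in chunks of NK = 32; NL0 = 2 threads share one src0 row (16 elements each) and NL1 = 4 threads share one src1 row (8 elements each). r0 and r1 are now absolute tile origins (tgpig.y*NR0, tgpig.x*NR1), so the clamps above compare against args.ne0/ne1 directly. Below is a minimal host-side C++ sketch of this mapping; the 128-thread threadgroup size (4 simdgroups of 32) is an assumption inferred from the kernel, not stated in the diff.

// Host-side sketch (not part of the kernel): reproduces the tile constants introduced by
// the patch and checks the thread-index -> source-row mapping. The 128-thread threadgroup
// size is an assumption here.
#include <cassert>

constexpr int NR0 = 64;     // rows of the output tile (along ne0) per threadgroup
constexpr int NR1 = 32;     // columns of the output tile (along ne1) per threadgroup
constexpr int NK  = 32;     // K elements consumed per loop_k iteration
constexpr int NL0 = NK/16;  // threads sharing one src0 row (2, each loading 16 elements)
constexpr int NL1 = NK/8;   // threads sharing one src1 row (4, each loading 8 elements)

static_assert(NR0*NL0 == 128 && NR1*NL1 == 128, "both mappings cover one 128-thread threadgroup");

int main() {
    for (int tiitg = 0; tiitg < 128; ++tiitg) {
        const int lr0 = tiitg/NL0; // src0 row loaded by this thread: 0 .. 63
        const int lr1 = tiitg/NL1; // src1 row loaded by this thread: 0 .. 31
        assert(lr0 < NR0 && lr1 < NR1);
    }
    return 0;
}
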
@@ -8166,35 +8173,44 @@ kernel void kernel_mul_mm(
         mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
 
-    short il = (tiitg % THREAD_PER_ROW);
+    const short il0 = (tiitg % NL0);
+
+    short il = il0;
 
     const int i12 = im%args.ne12;
     const int i13 = im/args.ne12;
 
     const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const short    offset1 = il/nl;
+    const short    offset1 = il0/nl;
 
-    device const block_q * x = (device const block_q *)(src0
-        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
 
-    const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
+    const short iy = 8*(tiitg % NL1);
 
     device const T1 * y = (device const T1 *)(src1
         + args.nb13*i13
         + args.nb12*i12
-        + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
+        + args.nb11*(r1 + lr1)
         + args.nb10*iy);
 
-    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
         // load data and store to threadgroup memory
         if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             // no need for dequantization
             for (short i = 0; i < 16; i++) {
-                *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-                + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-                + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                //const short lx = i%8;
+                //const short ly = (tiitg/NL0)%8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+
+                const short ib = 8*sx + sy;
+
+                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
             }
         } else {
             S0_4x4 temp_a;
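
The new stores address sa and sb as grids of 8x8 sub-tiles of 64 elements each: ib picks the sub-tile (K-chunk-major via sx, row-block via sy) and 8*ly + lx picks the element inside it, so that each simdgroup_load(..., lsma + 64*i, 8, 0, false) later reads one contiguous sub-tile with row stride 8. Within a sub-tile, sa keeps the src0 row index fastest (in effect the A sub-tile is stored transposed) while sb keeps the K index fastest. A host-side C++ sketch that reproduces the patch's index arithmetic and checks it against a closed form in (row, k); the 128-thread threadgroup size is again an assumption:

// Host-side sketch: verifies the sa/sb threadgroup-memory addressing used in the patch.
// Each 8x8 sub-tile occupies 64 consecutive elements.
#include <cassert>

constexpr int NK  = 32;
constexpr int NL0 = NK/16;  // 2
constexpr int NL1 = NK/8;   // 4

int main() {
    for (int tiitg = 0; tiitg < 128; ++tiitg) {
        // src0 tile: each thread stores 16 elements of row lr0, starting at k = 16*il0
        const int il0 = tiitg % NL0;
        const int lr0 = tiitg / NL0;
        for (int i = 0; i < 16; ++i) {
            const int sx = 2*il0 + i/8;
            const int sy = (tiitg/NL0)/8;
            const int lx = (tiitg/NL0)%8;
            const int ly = i%8;
            const int ib = 8*sx + sy;
            const int k  = 16*il0 + i;
            // sub-tile [k/8][lr0/8], element (k%8, lr0%8), row index fastest
            assert(64*ib + 8*ly + lx == 64*(8*(k/8) + lr0/8) + 8*(k%8) + lr0%8);
        }

        // src1 tile: each thread stores 8 elements of row lr1, starting at k = iy = 8*(tiitg%NL1)
        const int lr1 = tiitg / NL1;
        const int iy  = 8*(tiitg % NL1);
        for (int i = 0; i < 8; ++i) {
            const int sx = tiitg % NL1;
            const int sy = (tiitg/NL1)/8;
            const int lx = i;
            const int ly = (tiitg/NL1)%8;
            const int ib = 4*sx + sy;
            const int k  = iy + i;
            // sub-tile [k/8][lr1/8], element (lr1%8, k%8), K index fastest
            assert(64*ib + 8*ly + lx == 64*(4*(k/8) + lr1/8) + 8*(lr1%8) + k%8);
        }
    }
    return 0;
}
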
@@ -8203,91 +8219,122 @@ kernel void kernel_mul_mm(
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             FOR_UNROLL (short i = 0; i < 16; i++) {
-                *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-                + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-                + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                //const short lx = i%8;
+                //const short ly = (tiitg/NL0)%8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+
+                const short ib = 8*sx + sy;
+
+                // NOTE: this is massively slower.. WTF?
+                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
+
+                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
             }
         }
 
         if (FC_mul_mm_bc_inp) {
             for (short i = 0; i < 8; ++i) {
-                sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+                //const short lx = (tiitg/NL1)%8;
+                //const short ly = i;
+
+                const short ib = 4*sx + sy;
+
+                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
             }
         } else {
-            *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            const short dx = sx;
+            const short dy = sy;
+
+            const short ly = (tiitg/NL1)%8;
+
+            const short ib = 4*sx + sy;
+
+            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
         }
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
-        y += BLOCK_SIZE_K;
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+        y += NK;
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
-        threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
+        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
+        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
 
-        #pragma unroll(4)
-        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
             simdgroup_barrier(mem_flags::mem_none);
 
-            #pragma unroll(4)
-            for (short i = 0; i < 4; i++) {
-                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+            FOR_UNROLL (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
             }
 
-            #pragma unroll(2)
-            for (short i = 0; i < 2; i++) {
-                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
             }
 
             simdgroup_barrier(mem_flags::mem_none);
 
-            #pragma unroll(8)
-            for (short i = 0; i < 8; i++){
+            FOR_UNROLL (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
 
-            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
-            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsma += 8*64;
+            lsmb += 4*64;
         }
     }
 
-    if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) {
+    if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
         // if no bounds checks on the output are needed, we can directly write to device memory
         device float * C = (device float *) dst +
-            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
-            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+            (r0 + 32*(sgitg &  1)) + \
+            (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
 
         for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                      + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+
+        threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
+
         for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         if (sgitg == 0) {
-            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
+            for (int j = tiitg; j < nr1; j += NR1) {
+                device float  * D  = (device float  *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
                 device float4 * D4 = (device float4 *) D;
 
-                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float  * C  = temp_str + (j*NR0);
                 threadgroup float4 * C4 = (threadgroup float4 *) C;
 
                 int i = 0;
-                for (; i < n_rows/4; i++) {
+                for (; i < nr0/4; i++) {
                     *(D4 + i) = *(C4 + i);
                 }
 
                 i *= 4;
-                for (; i < n_rows; i++) {
+                for (; i < nr0; i++) {
                     *(D + i) = *(C + i);
                 }
             }
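
For the output stage, each simdgroup writes a 32x16 sub-block of the 64x32 tile: sgitg & 1 picks the half along ne0 and sgitg >> 1 the half along ne1, and the eight 8x8 accumulators mc[i] are laid out 4 across and 2 down inside it via the 8*(i%4) and 8*args.ne0*(i/4) offsets passed to simdgroup_store above. A quick host-side C++ check that this covers the tile exactly once, under the same 4-simdgroup assumption:

// Host-side sketch: checks that the 8 accumulators of each of the 4 simdgroups, placed at
// the offsets used by simdgroup_store in the patch, tile the 64x32 output block exactly once.
#include <cassert>

int main() {
    constexpr int NR0 = 64, NR1 = 32;
    int cover[NR1][NR0] = {};

    for (int sgitg = 0; sgitg < 4; ++sgitg) {
        for (int i = 0; i < 8; ++i) {
            const int x0 = 32*(sgitg &  1) + 8*(i%4); // offset along ne0 (added to r0)
            const int y0 = 16*(sgitg >> 1) + 8*(i/4); // offset along ne1 (added to r1)
            for (int y = 0; y < 8; ++y) {
                for (int x = 0; x < 8; ++x) {
                    cover[y0 + y][x0 + x]++;
                }
            }
        }
    }

    for (int y = 0; y < NR1; ++y) {
        for (int x = 0; x < NR0; ++x) {
            assert(cover[y][x] == 1);
        }
    }
    return 0;
}
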