@@ -313,12 +313,12 @@ void main() {
313313 sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
314314 }
315315#else
316- ACC_TYPE sums[WMITER * TM * WNITER * TN];
316+ ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2 ];
317317 FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
318- FLOAT_TYPE_VEC2 cache_b[TN] ;
318+ FLOAT_TYPE_VEC2 cache_b;
319319
320- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
321- sums[i] = ACC_TYPE( 0.0f);
320+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2 ; i++) {
321+ sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
322322 }
323323#endif
324324
@@ -360,20 +360,22 @@ void main() {
360360 cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
361361 }
362362 }
363- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
364- [[unroll]] for (uint j = 0; j < TN; j++) {
365- cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
366- }
367363
368- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
370- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
371- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
372- sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
364+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
365+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
366+ cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
367+
368+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369+ [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
370+ // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
371+ const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
372+ sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
373+ sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
373374 }
374375 }
375376 }
376377 }
378+
377379 }
378380#endif
379381
@@ -388,8 +390,9 @@ void main() {
388390 }
389391 }
390392#else
391- [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
392- sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
393+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
394+ sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
395+ sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
393396 }
394397#endif
395398#endif
@@ -463,14 +466,21 @@ void main() {
463466
464467 const u16vec2 row_idx = row_ids[row_i - ic * BN];
465468#endif // MUL_MAT_ID
466- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
469+ [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
470+ const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
467471#ifdef MUL_MAT_ID
468- if (dr_warp + cr < p.M) {
469- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
472+ if (dr_warp + 2 * cr < p.M) {
473+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
474+ }
475+ if (dr_warp + 2 * cr + 1 < p.M) {
476+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
470477 }
471478#else
472- if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
473- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
479+ if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
480+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
481+ }
482+ if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
483+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
474484 }
475485#endif // MUL_MAT_ID
476486 }
0 commit comments