@@ -98,11 +98,11 @@ layout (constant_id = 12) const uint LOAD_VEC_B_SHIFT = 0;
9898#ifdef COOPMAT
9999#define SHMEM_STRIDE (BK + 8)
100100#else
101- #define SHMEM_STRIDE (BK + 1)
101+ #define SHMEM_STRIDE (BK / 2 + 1)
102102#endif
103103
104- shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
105- shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
104+ shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
105+ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
106106
107107#ifdef MUL_MAT_ID
108108shared u16vec2 row_ids[3072];
@@ -223,8 +223,8 @@ void main() {
223223 }
224224#else
225225 ACC_TYPE sums[WMITER * TM * WNITER * TN];
226- FLOAT_TYPE cache_a[WMITER * TM];
227- FLOAT_TYPE cache_b[TN];
226+ FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
227+ FLOAT_TYPE_VEC2 cache_b[TN];
228228
229229 [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
230230 sums[i] = ACC_TYPE(0.0f);
@@ -262,7 +262,7 @@ void main() {
262262 }
263263 }
264264#else
265- [[unroll]] for (uint i = 0; i < BK; i++) {
265+ [[unroll]] for (uint i = 0; i < BK / 2 ; i++) {
266266 // Load from shared into cache
267267 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
268268 [[unroll]] for (uint j = 0; j < TM; j++) {
@@ -278,7 +278,7 @@ void main() {
278278 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
279279 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
280280 const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
281- sums[sums_idx] = fma(ACC_TYPE (cache_a[wsir * TM + cr]), ACC_TYPE (cache_b[cc]), sums[sums_idx] );
281+ sums[sums_idx] += dot(ACC_TYPE_VEC2 (cache_a[wsir * TM + cr]), ACC_TYPE_VEC2 (cache_b[cc]));
282282 }
283283 }
284284 }
0 commit comments