3131#include "types.comp"
3232
3333#ifndef LOAD_VEC_A
34- #define LOAD_VEC_A 1 
34+ #define LOAD_VEC_A 2 
3535#endif
3636#ifndef LOAD_VEC_B
37- #define LOAD_VEC_B 1 
37+ #define LOAD_VEC_B 2 
3838#endif
3939
4040#if !defined(TO_FLOAT_TYPE)
@@ -98,13 +98,13 @@ layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
9898layout (constant_id = 10) const uint WARP = 32;
9999
100100#ifdef COOPMAT
101- #define SHMEM_STRIDE (BK + 8 )
101+ #define SHMEM_STRIDE (BK / 2 + 4 )
102102#else
103- #define SHMEM_STRIDE (BK + 1)
103+ #define SHMEM_STRIDE (BK / 2  + 1)
104104#endif
105105
106- shared FLOAT_TYPE  buf_a[BM * SHMEM_STRIDE];
107- shared FLOAT_TYPE  buf_b[BN * SHMEM_STRIDE];
106+ shared FLOAT_TYPE_VEC2  buf_a[BM * SHMEM_STRIDE];
107+ shared FLOAT_TYPE_VEC2  buf_b[BN * SHMEM_STRIDE];
108108
109109#define NUM_WARPS (BLOCK_SIZE / WARP)
110110
@@ -302,8 +302,8 @@ void main() {
302302    }
303303#else
304304    ACC_TYPE sums[WMITER * TM * WNITER * TN];
305-     FLOAT_TYPE  cache_a[WMITER * TM];
306-     FLOAT_TYPE  cache_b[TN];
305+     FLOAT_TYPE_VEC2  cache_a[WMITER * TM];
306+     FLOAT_TYPE_VEC2  cache_b[TN];
307307
308308    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
309309        sums[i] = ACC_TYPE(0.0f);
@@ -312,13 +312,13 @@ void main() {
312312
313313    for (uint block = start_k; block < end_k; block += BK) {
314314        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
315-             load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a , end_k);
315+             load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
316316        }
317317        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
318318#if !defined(MUL_MAT_ID)
319-             load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b , end_k);
319+             load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
320320#else
321-             load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b , end_k);
321+             load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
322322#endif
323323        }
324324
@@ -331,17 +331,17 @@ void main() {
331331        [[unroll]] for (uint i = 0; i < BK; i += TK) {
332332            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
333333                // Load from shared into cache
334-                 coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
334+                 coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2 , SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
335335
336336                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
337-                     coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
337+                     coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2 , SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
338338
339339                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
340340                }
341341            }
342342        }
343343#else
344-         [[unroll]] for (uint i = 0; i < BK; i++) {
344+         [[unroll]] for (uint i = 0; i < BK / 2 ; i++) {
345345            // Load from shared into cache
346346            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
347347                [[unroll]] for (uint j = 0; j < TM; j++) {
@@ -357,7 +357,7 @@ void main() {
357357                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
358358                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
359359                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
360-                             sums[sums_idx] = fma(ACC_TYPE (cache_a[wsir * TM + cr]), ACC_TYPE (cache_b[cc]), sums[sums_idx] );
360+                             sums[sums_idx] += dot(ACC_TYPE_VEC2 (cache_a[wsir * TM + cr]), ACC_TYPE_VEC2 (cache_b[cc]));
361361                        }
362362                    }
363363                }
0 commit comments