3131#include "types.comp"
3232
3333#ifndef LOAD_VEC_A
34- #define LOAD_VEC_A 2
34+ #define LOAD_VEC_A 1
3535#endif
3636#ifndef LOAD_VEC_B
37- #define LOAD_VEC_B 2
37+ #define LOAD_VEC_B 1
3838#endif
3939
4040#if !defined(TO_FLOAT_TYPE)
@@ -98,13 +98,13 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
9898layout (constant_id = 10) const uint WARP = 32;
9999
100100#ifdef COOPMAT
101- #define SHMEM_STRIDE (BK / 2 + 4 )
101+ #define SHMEM_STRIDE (BK + 8 )
102102#else
103- #define SHMEM_STRIDE (BK / 2 + 1)
103+ #define SHMEM_STRIDE (BK + 1)
104104#endif
105105
106- shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
107- shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
106+ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
107+ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
108108
109109#define NUM_WARPS (BLOCK_SIZE / WARP)
110110
@@ -302,8 +302,8 @@ void main() {
302302 }
303303#else
304304 ACC_TYPE sums[WMITER * TM * WNITER * TN];
305- FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
306- FLOAT_TYPE_VEC2 cache_b[TN];
305+ FLOAT_TYPE cache_a[WMITER * TM];
306+ FLOAT_TYPE cache_b[TN];
307307
308308 [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
309309 sums[i] = ACC_TYPE(0.0f);
@@ -312,13 +312,13 @@ void main() {
312312
313313 for (uint block = start_k; block < end_k; block += BK) {
314314 [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
315- load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
315+ load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a , end_k);
316316 }
317317 [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
318318#if !defined(MUL_MAT_ID)
319- load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
319+ load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b , end_k);
320320#else
321- load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
321+ load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b , end_k);
322322#endif
323323 }
324324
@@ -331,17 +331,17 @@ void main() {
331331 [[unroll]] for (uint i = 0; i < BK; i += TK) {
332332 [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
333333 // Load from shared into cache
334- coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2 , SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
334+ coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
335335
336336 [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
337- coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2 , SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
337+ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
338338
339339 sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
340340 }
341341 }
342342 }
343343#else
344- [[unroll]] for (uint i = 0; i < BK / 2 ; i++) {
344+ [[unroll]] for (uint i = 0; i < BK; i++) {
345345 // Load from shared into cache
346346 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
347347 [[unroll]] for (uint j = 0; j < TM; j++) {
@@ -357,7 +357,7 @@ void main() {
357357 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
358358 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
359359 const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
360- sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x ), ACC_TYPE(cache_b[cc].x ), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]) );
360+ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
361361 }
362362 }
363363 }
0 commit comments