@@ -100,7 +100,6 @@ layout (push_constant) uniform parameter
100100layout (constant_id = 0) const uint BLOCK_SIZE = 64;
101101layout (constant_id = 1) const uint BM = 64;
102102layout (constant_id = 2) const uint BN = 64;
103- layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
104103layout (constant_id = 4) const uint WM = 32;
105104layout (constant_id = 5) const uint WN = 32;
106105layout (constant_id = 6) const uint WMITER = 2;
@@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2;
109108layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
110109layout (constant_id = 10) const uint WARP = 32;
111110
111+ #if defined(DATA_A_F32) || defined(DATA_A_F16) // branch taken when either DATA_A_F32 or DATA_A_F16 is defined
112+ #define BK 32 // in this branch BK is a fixed preprocessor constant rather than specialization constant 3
113+ #define BK_STEP 4 // divisor of the shared-to-cache loop bound (BK / BK_STEP); in this branch each iteration fills a VEC4 cache entry from two consecutive buf_a / buf_b elements
114+ #else
115+ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
116+ #define BK_STEP 2 // divisor of the shared-to-cache loop bound (BK / BK_STEP); in this branch each iteration reads one buf_a / buf_b element into a VEC2 cache entry
117+ #endif
118+
112119#ifdef COOPMAT
113120#define SHMEM_STRIDE (BK / 2 + 4)
114121#else
@@ -244,8 +251,13 @@ void main() {
244251 }
245252#else
246253 ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
254+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
247255 FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
248256 FLOAT_TYPE_VEC4 cache_b;
257+ #else
258+ FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
259+ FLOAT_TYPE_VEC2 cache_b;
260+ #endif
249261
250262 [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
251263 sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
@@ -283,30 +295,41 @@ void main() {
283295 }
284296 }
285297#else
286- [[unroll]] for (uint i = 0; i < BK / 4 ; i++) {
298+ [[unroll]] for (uint i = 0; i < BK / BK_STEP ; i++) {
287299 // Load from shared into cache
288300 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
289- const uint base_a = (warp_r * WM + wsir * WSUBM + tiwr * TM) * SHMEM_STRIDE + 2 * i;
290301 [[unroll]] for (uint j = 0; j < TM; j++) {
291- cache_a[wsir * TM + j].xy = buf_a[base_a + j * SHMEM_STRIDE ];
292- cache_a[wsir * TM + j].zw = buf_a[base_a + j * SHMEM_STRIDE + 1];
302+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
303+ cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ];
304+ cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];
305+ #else
306+ cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
307+ #endif
293308 }
294309 }
295310
296311 [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
297- const uint base_b = (warp_c * WN + wsic * WSUBN + tiwc * TN) * SHMEM_STRIDE + 2 * i;
298312 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
299- cache_b.xy = buf_b[base_b + cc * SHMEM_STRIDE ];
300- cache_b.zw = buf_b[base_b + cc * SHMEM_STRIDE + 1];
313+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
314+ cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ];
315+ cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];
316+ #else
317+ cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
318+ #endif
301319
302320 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
303321 [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
304322 // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
305323 const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
324+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
306325 sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
307326 fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
308327 sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
309328 fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
329+ #else
330+ sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
331+ sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
332+ #endif
310333 }
311334 }
312335 }
0 commit comments