@@ -81,33 +81,23 @@ layout (constant_id = 10) const uint WARP = 32;
8181
8282#ifdef COOPMAT
8383#define SHMEM_STRIDE (BK / 4 + 4)
84- #else
85- #define SHMEM_STRIDE (BK / 4 + 1)
8684#endif
8785
88- shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
89-
90- #ifdef DATA_A_QUANT_K
91- #define SHMEM_SCALES_STRIDE (SCALES_PER_32 + 1)
92- shared uint8_t buf_a_scales[BM * SHMEM_SCALES_STRIDE];
93- #endif
86+ #define MMQ_SHMEM
9487
95- #ifndef COOPMAT
96- #if QUANT_AUXF == 1
97- shared FLOAT_TYPE buf_a_dm[BM];
98- #else
99- shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
100- #endif
101- #endif
88+ #include "mul_mmq_shmem_types.glsl"
10289
103- shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
104- #ifndef COOPMAT
105- shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
106- #endif
90+ // Shared memory cache
91+ shared block_a_cache buf_a[BM];
92+ shared block_b_cache buf_b[BN];
93+ // Register cache
94+ block_a_cache cache_a[WMITER * TM];
95+ block_b_cache cache_b[TN];
10796
108- #define LOAD_VEC_A (4 * QUANT_R )
97+ #define LOAD_VEC_A (4 * QUANT_R_MMQ )
10998#define LOAD_VEC_B 16
11099
100+ // TODO: Recheck if this can work with mul_mat_id
111101#ifdef MUL_MAT_ID
112102shared u16vec2 row_ids[4096];
113103#endif // MUL_MAT_ID
@@ -230,13 +220,6 @@ void main() {
230220 sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
231221 }
232222#else
233- int32_t cache_a_qs[WMITER * TM * BK / 4];
234-
235- #ifdef DATA_A_QUANT_K
236- uint8_t cache_a_scales[WMITER * TM * SCALES_PER_32];
237- #endif
238-
239- int32_t cache_b_qs[TN * BK / 4];
240223
241224 ACC_TYPE sums[WMITER * TM * WNITER * TN];
242225
@@ -245,40 +228,13 @@ void main() {
245228 }
246229#endif
247230
248- #if QUANT_AUXF == 1
249- FLOAT_TYPE cache_a_dm[WMITER * TM];
250- #else
251- FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
252- #endif
253-
254- FLOAT_TYPE_VEC2 cache_b_ds[TN];
255-
256231 for (uint block = start_k; block < end_k; block += BK) {
257232 [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
258233 const uint buf_ib = loadc_a + l;
259234 const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
260235 const uint iqs = loadr_a;
261236
262- if (iqs == 0) {
263- #if QUANT_AUXF == 1
264- buf_a_dm[buf_ib] = get_d(ib);
265- #else
266- buf_a_dm[buf_ib] = get_dm(ib);
267- #endif
268- }
269- #if QUANT_R == 1
270- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
271- #else
272- const i32vec2 vals = repack(ib, iqs);
273- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
274- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
275- #endif
276-
277- #ifdef DATA_A_QUANT_K
278- if (iqs % 4 == 0) {
279- buf_a_scales[buf_ib * SHMEM_SCALES_STRIDE + iqs / 4] = get_scale(ib, iqs);
280- }
281- #endif
237+ block_a_to_shmem(buf_ib, ib, iqs);
282238 }
283239 [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
284240#ifdef MUL_MAT_ID
@@ -297,13 +253,13 @@ void main() {
297253 const uint buf_ib = loadc_b + l;
298254
299255 if (iqs == 0) {
300- buf_b_ds [buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
256+ buf_b [buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
301257 }
302258 const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
303- buf_b_qs [buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
304- buf_b_qs [buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
305- buf_b_qs [buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
306- buf_b_qs [buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
259+ buf_b [buf_ib].qs[ iqs * 4 ] = values.x;
260+ buf_b [buf_ib].qs[ iqs * 4 + 1] = values.y;
261+ buf_b [buf_ib].qs[ iqs * 4 + 2] = values.z;
262+ buf_b [buf_ib].qs[ iqs * 4 + 3] = values.w;
307263 }
308264
309265 barrier();
@@ -346,25 +302,19 @@ void main() {
346302 // Load from shared into cache
347303 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
348304 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
349- const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
350- cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
351- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
352- cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
353- }
354- #ifdef DATA_A_QUANT_K
355- [[unroll]] for (uint s = 0; s < SCALES_PER_32; s++) {
356- cache_a_scales[(wsir * TM + cr) * SCALES_PER_32 + s] = buf_a_scales[ib * SHMEM_SCALES_STRIDE + s];
357- }
358- #endif
305+ const uint reg_ib = wsir * TM + cr;
306+ const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
307+
308+ block_a_to_registers(reg_ib, buf_ib);
359309 }
360310 }
361311
362312 [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
363313 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
364314 const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
365- cache_b_ds [cc] = buf_b_ds [ib];
366- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k ++) {
367- cache_b_qs [cc * (BK / 4) + idx_k] = buf_b_qs [ib * SHMEM_STRIDE + idx_k ];
315+ cache_b [cc].ds = buf_b [ib].ds ;
316+ [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs ++) {
317+ cache_b [cc].qs[iqs] = buf_b [ib].qs[iqs ];
368318 }
369319 }
370320
@@ -374,44 +324,7 @@ void main() {
374324 const uint cache_a_idx = wsir * TM + cr;
375325 const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
376326
377- #if defined(DATA_A_QUANT_LEGACY)
378- int32_t q_sum = 0;
379- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
380- q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
381- cache_b_qs[cc * (BK / 4) + idx_k]);
382- }
383-
384- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
385- #elif defined(DATA_A_QUANT_K)
386- int32_t sum_d = 0;
387- int32_t sum_m = 0;
388-
389- const int32_t scale0 = cache_a_scales[cache_a_idx * SCALES_PER_32];
390- const int32_t scale1 = cache_a_scales[cache_a_idx * SCALES_PER_32 + 1];
391- int32_t scale_m = scale0 >> 4;
392- scale_m |= scale_m << 8;
393- scale_m |= scale_m << 16;
394-
395- [[unroll]] for (uint idx_k = 0; idx_k < BK / 8; idx_k++) {
396- sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
397- cache_b_qs[cc * (BK / 4) + idx_k]) * (scale0 & 0xF);
398- sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
399- }
400-
401- scale_m = scale1 >> 4;
402- scale_m |= scale_m << 8;
403- scale_m |= scale_m << 16;
404-
405- [[unroll]] for (uint idx_k = BK / 8; idx_k < BK / 4; idx_k++) {
406- sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
407- cache_b_qs[cc * (BK / 4) + idx_k]) * (scale1 & 0xF);
408- sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
409- }
410-
411- sums[sums_idx] += mul_q8_1(sum_d, sum_m, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
412- #else
413- #error unsupported
414- #endif
327+ sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
415328 }
416329 }
417330 }
0 commit comments