Skip to content

Commit c4711d8

Browse files
committed
Refactor mmq caching
1 parent 2d6efa4 commit c4711d8

File tree

3 files changed

+273
-139
lines changed

3 files changed

+273
-139
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 24 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -81,33 +81,23 @@ layout (constant_id = 10) const uint WARP = 32;
8181

8282
#ifdef COOPMAT
8383
#define SHMEM_STRIDE (BK / 4 + 4)
84-
#else
85-
#define SHMEM_STRIDE (BK / 4 + 1)
8684
#endif
8785

88-
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
89-
90-
#ifdef DATA_A_QUANT_K
91-
#define SHMEM_SCALES_STRIDE (SCALES_PER_32 + 1)
92-
shared uint8_t buf_a_scales[BM * SHMEM_SCALES_STRIDE];
93-
#endif
86+
#define MMQ_SHMEM
9487

95-
#ifndef COOPMAT
96-
#if QUANT_AUXF == 1
97-
shared FLOAT_TYPE buf_a_dm[BM];
98-
#else
99-
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
100-
#endif
101-
#endif
88+
#include "mul_mmq_shmem_types.glsl"
10289

103-
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
104-
#ifndef COOPMAT
105-
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
106-
#endif
90+
// Shared memory cache
91+
shared block_a_cache buf_a[BM];
92+
shared block_b_cache buf_b[BN];
93+
// Register cache
94+
block_a_cache cache_a[WMITER * TM];
95+
block_b_cache cache_b[TN];
10796

108-
#define LOAD_VEC_A (4 * QUANT_R)
97+
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
10998
#define LOAD_VEC_B 16
11099

100+
// TODO: Recheck if this can work with mul_mat_id
111101
#ifdef MUL_MAT_ID
112102
shared u16vec2 row_ids[4096];
113103
#endif // MUL_MAT_ID
@@ -230,13 +220,6 @@ void main() {
230220
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
231221
}
232222
#else
233-
int32_t cache_a_qs[WMITER * TM * BK / 4];
234-
235-
#ifdef DATA_A_QUANT_K
236-
uint8_t cache_a_scales[WMITER * TM * SCALES_PER_32];
237-
#endif
238-
239-
int32_t cache_b_qs[TN * BK / 4];
240223

241224
ACC_TYPE sums[WMITER * TM * WNITER * TN];
242225

@@ -245,40 +228,13 @@ void main() {
245228
}
246229
#endif
247230

248-
#if QUANT_AUXF == 1
249-
FLOAT_TYPE cache_a_dm[WMITER * TM];
250-
#else
251-
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
252-
#endif
253-
254-
FLOAT_TYPE_VEC2 cache_b_ds[TN];
255-
256231
for (uint block = start_k; block < end_k; block += BK) {
257232
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
258233
const uint buf_ib = loadc_a + l;
259234
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
260235
const uint iqs = loadr_a;
261236

262-
if (iqs == 0) {
263-
#if QUANT_AUXF == 1
264-
buf_a_dm[buf_ib] = get_d(ib);
265-
#else
266-
buf_a_dm[buf_ib] = get_dm(ib);
267-
#endif
268-
}
269-
#if QUANT_R == 1
270-
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
271-
#else
272-
const i32vec2 vals = repack(ib, iqs);
273-
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
274-
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
275-
#endif
276-
277-
#ifdef DATA_A_QUANT_K
278-
if (iqs % 4 == 0) {
279-
buf_a_scales[buf_ib * SHMEM_SCALES_STRIDE + iqs / 4] = get_scale(ib, iqs);
280-
}
281-
#endif
237+
block_a_to_shmem(buf_ib, ib, iqs);
282238
}
283239
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
284240
#ifdef MUL_MAT_ID
@@ -297,13 +253,13 @@ void main() {
297253
const uint buf_ib = loadc_b + l;
298254

299255
if (iqs == 0) {
300-
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
256+
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
301257
}
302258
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
303-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
304-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
305-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
306-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
259+
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
260+
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
261+
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
262+
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
307263
}
308264

309265
barrier();
@@ -346,25 +302,19 @@ void main() {
346302
// Load from shared into cache
347303
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
348304
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
349-
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
350-
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
351-
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
352-
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
353-
}
354-
#ifdef DATA_A_QUANT_K
355-
[[unroll]] for (uint s = 0; s < SCALES_PER_32; s++) {
356-
cache_a_scales[(wsir * TM + cr) * SCALES_PER_32 + s] = buf_a_scales[ib * SHMEM_SCALES_STRIDE + s];
357-
}
358-
#endif
305+
const uint reg_ib = wsir * TM + cr;
306+
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
307+
308+
block_a_to_registers(reg_ib, buf_ib);
359309
}
360310
}
361311

362312
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
363313
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
364314
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
365-
cache_b_ds[cc] = buf_b_ds[ib];
366-
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
367-
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
315+
cache_b[cc].ds = buf_b[ib].ds;
316+
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
317+
cache_b[cc].qs[iqs] = buf_b[ib].qs[iqs];
368318
}
369319
}
370320

@@ -374,44 +324,7 @@ void main() {
374324
const uint cache_a_idx = wsir * TM + cr;
375325
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
376326

377-
#if defined(DATA_A_QUANT_LEGACY)
378-
int32_t q_sum = 0;
379-
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
380-
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
381-
cache_b_qs[cc * (BK / 4) + idx_k]);
382-
}
383-
384-
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
385-
#elif defined(DATA_A_QUANT_K)
386-
int32_t sum_d = 0;
387-
int32_t sum_m = 0;
388-
389-
const int32_t scale0 = cache_a_scales[cache_a_idx * SCALES_PER_32];
390-
const int32_t scale1 = cache_a_scales[cache_a_idx * SCALES_PER_32 + 1];
391-
int32_t scale_m = scale0 >> 4;
392-
scale_m |= scale_m << 8;
393-
scale_m |= scale_m << 16;
394-
395-
[[unroll]] for (uint idx_k = 0; idx_k < BK / 8; idx_k++) {
396-
sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
397-
cache_b_qs[cc * (BK / 4) + idx_k]) * (scale0 & 0xF);
398-
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
399-
}
400-
401-
scale_m = scale1 >> 4;
402-
scale_m |= scale_m << 8;
403-
scale_m |= scale_m << 16;
404-
405-
[[unroll]] for (uint idx_k = BK / 8; idx_k < BK / 4; idx_k++) {
406-
sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
407-
cache_b_qs[cc * (BK / 4) + idx_k]) * (scale1 & 0xF);
408-
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
409-
}
410-
411-
sums[sums_idx] += mul_q8_1(sum_d, sum_m, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
412-
#else
413-
#error unsupported
414-
#endif
327+
sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
415328
}
416329
}
417330
}

0 commit comments

Comments (0)