Skip to content

Commit ded8089

Browse files
committed
Load 4 quant blocks into shared memory in one step
1 parent 0775df7 commit ded8089

File tree

3 files changed

+129
-75
lines changed

3 files changed

+129
-75
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,13 @@ layout (constant_id = 10) const uint WARP = 32;
7777

7878
#include "mul_mmq_shmem_types.glsl"
7979

80+
#ifndef BK_STEP
81+
#define BK_STEP 4
82+
#endif
83+
8084
// Shared memory cache
81-
shared block_a_cache buf_a[BM];
82-
shared block_b_cache buf_b[BN];
85+
shared block_a_cache buf_a[BM * BK_STEP / QUANT_BLOCK_FACTOR];
86+
shared block_b_cache buf_b[BN * BK_STEP / QUANT_BLOCK_FACTOR];
8387
// Register cache
8488
block_a_cache cache_a[WMITER * TM];
8589
block_b_cache cache_b;
@@ -185,70 +189,64 @@ void main() {
185189
sums[i] = ACC_TYPE_VEC2(0.0f);
186190
}
187191

188-
for (uint block = start_k; block < end_k; block += BK) {
192+
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
189193
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
190194
const uint buf_ib = loadc_a + l;
191195
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
192196
const uint iqs = loadr_a;
193197

194-
block_a_to_shmem(buf_ib, ib, iqs);
198+
[[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
199+
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
200+
}
195201
}
196202
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
203+
const uint buf_ib = loadc_b + l;
204+
197205
#ifdef MUL_MAT_ID
198-
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
206+
const u16vec2 row_idx = row_ids[ic * BN + buf_ib];
199207
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
200208
const uint ib = idx / 8;
201209
const uint iqs = idx & 0x7;
202210
#else
203-
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
204-
const uint ib_outer = ib / 4;
205-
const uint ib_inner = ib % 4;
211+
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
206212

207213
const uint iqs = loadr_b;
208214
#endif
209215

210-
const uint buf_ib = loadc_b + l;
211-
212-
if (iqs == 0) {
213-
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
216+
[[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
217+
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
214218
}
215-
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
216-
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
217-
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
218-
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
219-
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
220219
}
221220

222221
barrier();
223222

224-
pos_a_ib += 1;
225-
pos_b_ib += 1;
223+
pos_a_ib += BK_STEP;
224+
pos_b_ib += BK_STEP;
226225

227-
// Load from shared into cache
228-
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
229-
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
230-
const uint reg_ib = wsir * TM + cr;
231-
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
226+
for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
227+
// Load from shared into cache
228+
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
229+
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
230+
const uint reg_ib = wsir * TM + cr;
231+
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
232232

233-
block_a_to_registers(reg_ib, buf_ib);
233+
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
234+
}
234235
}
235-
}
236236

237-
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
238-
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
239-
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
240-
cache_b.ds = buf_b[ib].ds;
241-
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
242-
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
243-
}
237+
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
238+
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
239+
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
240+
block_b_to_registers(ib);
244241

245-
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
246-
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
247-
const uint cache_a_idx = wsir * TM + cr * 2;
248-
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
242+
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
243+
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
244+
const uint cache_a_idx = wsir * TM + cr * 2;
245+
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
249246

250-
sums[sums_idx].x += mmq_dot_product(cache_a_idx);
251-
sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
247+
sums[sums_idx].x += mmq_dot_product(cache_a_idx);
248+
sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
249+
}
252250
}
253251
}
254252
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -233,69 +233,111 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
233233
#ifdef MMQ_SHMEM
234234
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
235235
const uint ib_k = ib / 8;
236-
const uint iqs_k = (ib % 8) * 8 + iqs * 4;
236+
const uint iqs_k = (ib % 8) * 8 + iqs;
237237

238238
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
239-
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
239+
// const uint qs_shift = ((iqs_k % 32) / 8) * 2;
240240

241241
// Repack 4x4 quants into one int
242-
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
243-
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
244-
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
245-
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
242+
// const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
243+
// const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
244+
// const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
245+
// const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
246246

247-
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
247+
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
248248

249249
if (iqs == 0) {
250-
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
251250
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
251+
buf_a[buf_ib].scales[0] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16]);
252+
}
253+
if (iqs == 1) {
254+
buf_a[buf_ib].scales[1] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1]);
252255
}
253256
}
254257

255258
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
256259
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
257-
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
258260

259261
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
262+
cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
263+
}
264+
265+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
260266
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
261267
}
262268
}
263269

264270
ACC_TYPE mmq_dot_product(const uint ib_a) {
265-
int32_t sum_d = 0;
266-
int32_t sum_m = 0;
271+
float sum_d = 0;
272+
float sum_m = 0;
267273

268-
uint8_t scale = cache_a[ib_a].scales[0];
269-
int32_t scale_m = int32_t(scale >> 4);
270-
scale_m |= scale_m << 8;
271-
scale_m |= scale_m << 16;
274+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
275+
const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
276+
[[unroll]] for (uint ib_b = 0; ib_b < 4; ib_b++) {
277+
const uint8_t scale = cache_a[ib_a].scales[ib_b / 2][(ib_b % 2) * 2 + (iqs / 4)];
278+
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
279+
const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2)) & 0x03030303);
280+
281+
sum_d += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
282+
sum_m += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
283+
}
284+
}
272285

273-
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
274-
const uint qs_shift = iqs * 2;
286+
return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
287+
}
288+
#endif // MMQ_SHMEM
289+
#endif
275290

276-
const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303);
291+
#ifdef MMQ_SHMEM
292+
#if defined(DATA_A_QUANT_LEGACY)
293+
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
294+
const uint ib_outer = ib / 4;
295+
const uint ib_inner = ib % 4;
277296

278-
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
279-
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
297+
if (iqs == 0) {
298+
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280299
}
281300

282-
scale = cache_a[ib_a].scales[1];
283-
scale_m = int32_t(scale >> 4);
284-
scale_m |= scale_m << 8;
285-
scale_m |= scale_m << 16;
286-
287-
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
288-
const uint qs_shift = (iqs - 4) * 2;
289-
290-
const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303);
301+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
302+
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
303+
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
304+
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
305+
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
306+
}
291307

292-
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
293-
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
308+
void block_b_to_registers(const uint ib) {
309+
cache_b.ds = buf_b[ib].ds;
310+
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
311+
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
312+
}
313+
}
314+
#elif defined(DATA_A_QUANT_K)
315+
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
316+
const uint ib_outer = ib / 4;
317+
318+
buf_b[buf_ib].ds[iqs * 2 ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 ]);
319+
buf_b[buf_ib].ds[iqs * 2 + 1] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1]);
320+
321+
[[unroll]] for (uint ib_inner = 0; ib_inner < 4; ib_inner++) {
322+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
323+
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 ] = values.x;
324+
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1] = values.y;
325+
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2] = values.z;
326+
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3] = values.w;
294327
}
328+
}
295329

296-
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
330+
void block_b_to_registers(const uint ib) {
331+
[[unroll]] for (uint i = 0; i < 4; i++) {
332+
cache_b.ds[i] = buf_b[ib].ds[i];
333+
}
334+
[[unroll]] for (uint iqs = 0; iqs < 32; iqs++) {
335+
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
336+
}
297337
}
298-
#endif // MMQ_SHMEM
338+
#else
339+
#error unimplemented
340+
#endif
299341
#endif
300342

301343
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,31 @@ struct block_a_cache {
3131
FLOAT_TYPE dm;
3232
};
3333
#elif defined(DATA_A_Q2_K)
34-
#define QUANT_R_MMQ 4
34+
#define QUANT_R_MMQ 1
3535
struct block_a_cache
3636
{
37-
uint32_t qs[2];
38-
u8vec2 scales;
37+
uint32_t qs[8];
38+
u8vec4 scales[2];
3939
FLOAT_TYPE_VEC2 dm;
4040
};
4141
#endif
4242

43+
#if defined(DATA_A_QUANT_LEGACY)
44+
#define QUANT_BLOCK_FACTOR 1
45+
4346
struct block_b_cache
4447
{
4548
int32_t qs[8];
4649
FLOAT_TYPE_VEC2 ds;
4750
};
51+
#elif defined(DATA_A_QUANT_K)
52+
#define QUANT_BLOCK_FACTOR 4
53+
54+
struct block_b_cache
55+
{
56+
int32_t qs[32];
57+
FLOAT_TYPE_VEC2 ds[4];
58+
};
59+
#else
60+
#error unimplemented
61+
#endif

0 commit comments

Comments (0)