Commit e978f66

Pack q2_k blocks into caches of 32
1 parent ded8089 commit e978f66

4 files changed, +28 -78 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 1 deletion
@@ -2512,6 +2512,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
         s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

+        // Integer MMQ has a smaller shared memory profile, but heavier register use
         l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
         m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
         s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
@@ -3149,7 +3150,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
     // reusing CREATE_MM from the fp32 path
     if ((device->coopmat2 || device->coopmat_support)
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         && !device->coopmat_bf16_support
 #endif
     ) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 5 additions & 5 deletions
@@ -82,8 +82,8 @@ layout (constant_id = 10) const uint WARP = 32;
 #endif

 // Shared memory cache
-shared block_a_cache buf_a[BM * BK_STEP / QUANT_BLOCK_FACTOR];
-shared block_b_cache buf_b[BN * BK_STEP / QUANT_BLOCK_FACTOR];
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
 // Register cache
 block_a_cache cache_a[WMITER * TM];
 block_b_cache cache_b;
@@ -195,7 +195,7 @@ void main() {
         const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
         const uint iqs = loadr_a;

-        [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+        [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
         }
     }
@@ -213,7 +213,7 @@ void main() {
         const uint iqs = loadr_b;
 #endif

-        [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+        [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
         }
     }
@@ -223,7 +223,7 @@ void main() {
         pos_a_ib += BK_STEP;
         pos_b_ib += BK_STEP;

-        for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             // Load from shared into cache
             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
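With QUANT_BLOCK_FACTOR gone, buf_a and buf_b are sized uniformly: one cache entry per 32-value k-slice for every quant type, assuming BK = 32 as in the Q8_1 blocks. A minimal host-side C++ sketch of the entry counts, with hypothetical tile sizes that are not read from the shader constants:

    #include <cstdio>

    int main() {
        // Hypothetical tile shape, for illustration only.
        const unsigned BM = 64, BN = 64, BK_STEP = 2;

        // Before this commit the Q2_K path divided the entry count by
        // QUANT_BLOCK_FACTOR = 4, so one cache entry spanned four k-slices.
        printf("old Q2_K buf_a entries: %u\n", BM * BK_STEP / 4);

        // Now every quant type gets one entry per 32-value slice.
        printf("new buf_a entries: %u\n", BM * BK_STEP);
        printf("new buf_b entries: %u\n", BN * BK_STEP);
        return 0;
    }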

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 18 additions & 55 deletions
@@ -233,63 +233,53 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
 #ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
-    const uint iqs_k = (ib % 8) * 8 + iqs;
+    const uint iqs_k = (ib % 8) * 8 + iqs * 4;

     const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
-    // const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;

     // Repack 4x4 quants into one int
-    // const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
-    // const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
-    // const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
-    // const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;

-    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);

     if (iqs == 0) {
         buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
-        buf_a[buf_ib].scales[0] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16]);
-    }
-    if (iqs == 1) {
-        buf_a[buf_ib].scales[1] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1]);
+        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
     }
 }

 void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].scales = buf_a[buf_ib].scales;

     [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
-        cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
-    }
-
-    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
         cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
     }
 }

 ACC_TYPE mmq_dot_product(const uint ib_a) {
-    float sum_d = 0;
-    float sum_m = 0;
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;

     [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
-        const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
-        [[unroll]] for (uint ib_b = 0; ib_b < 4; ib_b++) {
-            const uint8_t scale = cache_a[ib_a].scales[ib_b / 2][(ib_b % 2) * 2 + (iqs / 4)];
-            const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
-            const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2)) & 0x03030303);
-
-            sum_d += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
-            sum_m += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
-        }
+        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
     }

-    return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
 #endif

 #ifdef MMQ_SHMEM
-#if defined(DATA_A_QUANT_LEGACY)
 void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_outer = ib / 4;
     const uint ib_inner = ib % 4;
@@ -311,33 +301,6 @@ void block_b_to_registers(const uint ib) {
         cache_b.qs[iqs] = buf_b[ib].qs[iqs];
     }
 }
-#elif defined(DATA_A_QUANT_K)
-void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
-    const uint ib_outer = ib / 4;
-
-    buf_b[buf_ib].ds[iqs * 2    ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2    ]);
-    buf_b[buf_ib].ds[iqs * 2 + 1] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1]);
-
-    [[unroll]] for (uint ib_inner = 0; ib_inner < 4; ib_inner++) {
-        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4    ] = values.x;
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1] = values.y;
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2] = values.z;
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3] = values.w;
-    }
-}
-
-void block_b_to_registers(const uint ib) {
-    [[unroll]] for (uint i = 0; i < 4; i++) {
-        cache_b.ds[i] = buf_b[ib].ds[i];
-    }
-    [[unroll]] for (uint iqs = 0; iqs < 32; iqs++) {
-        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
-    }
-}
-#else
-#error unimplemented
-#endif
 #endif

 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
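The core of the change above is that block_a_to_shmem now actually performs the repack that was previously commented out: four 2-bit quant groups taken from consecutive qs words are folded into one 32-bit word, so two such words hold a full 32-quant cache entry. A minimal host-side C++ sketch of that packing and of the unpack used by mmq_dot_product; the input values and the extract_2bit helper are made up for illustration:

    #include <cstdint>
    #include <cstdio>

    // Extract one 2-bit quant per byte from a packed Q2_K qs word.
    static uint32_t extract_2bit(uint32_t qs_word, uint32_t shift) {
        return (qs_word >> shift) & 0x03030303u;
    }

    int main() {
        // Four consecutive 32-bit qs words (hypothetical contents).
        const uint32_t qs[4] = {0x11223344u, 0x55667788u, 0x99aabbccu, 0xddeeff00u};
        const uint32_t shift = 2; // selects one of the four 2-bit planes in each byte

        // Same packing as the shader: 4x4 two-bit quants merged into one int.
        const uint32_t vals0 = extract_2bit(qs[0], shift);
        const uint32_t vals1 = extract_2bit(qs[1], shift);
        const uint32_t vals2 = extract_2bit(qs[2], shift);
        const uint32_t vals3 = extract_2bit(qs[3], shift);
        const uint32_t packed = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);

        // Unpacking mirrors mmq_dot_product: a shift of (i % 4) * 2 and a
        // 0x03030303 mask recover the i-th group of four 2-bit quants.
        for (uint32_t i = 0; i < 4; i++) {
            printf("group %u: 0x%08x\n", i, (packed >> (i * 2)) & 0x03030303u);
        }
        return 0;
    }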

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 3 additions & 17 deletions
@@ -31,31 +31,17 @@ struct block_a_cache {
     FLOAT_TYPE dm;
 };
 #elif defined(DATA_A_Q2_K)
-#define QUANT_R_MMQ 1
+#define QUANT_R_MMQ 4
 struct block_a_cache
 {
-    uint32_t qs[8];
-    u8vec4 scales[2];
+    uint32_t qs[2];
+    u8vec2 scales;
     FLOAT_TYPE_VEC2 dm;
 };
 #endif

-#if defined(DATA_A_QUANT_LEGACY)
-#define QUANT_BLOCK_FACTOR 1
-
 struct block_b_cache
 {
     int32_t qs[8];
     FLOAT_TYPE_VEC2 ds;
 };
-#elif defined(DATA_A_QUANT_K)
-#define QUANT_BLOCK_FACTOR 4
-
-struct block_b_cache
-{
-    int32_t qs[32];
-    FLOAT_TYPE_VEC2 ds[4];
-};
-#else
-#error unimplemented
-#endif
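For reference, a rough host-side mirror of the new DATA_A_Q2_K block_a_cache; the struct name, the 16-bit-float assumption for FLOAT_TYPE_VEC2, and the size check are illustrative and not taken from the shader headers:

    #include <cstdint>

    // Hypothetical C++ mirror of the new Q2_K block_a_cache.
    struct block_a_cache_q2_k {
        uint32_t qs[2];      // 2 x 16 repacked 2-bit quants = 32 quants per entry
        uint8_t  scales[2];  // one scale/min byte per 16-quant Q2_K sub-block
        uint16_t dm[2];      // super-block d and dmin, assumed stored as half floats
    };

    // A 32-quant slice, its two scales and dm now fit in 16 bytes,
    // versus the previous uint32_t qs[8] + u8vec4 scales[2] layout.
    static_assert(sizeof(block_a_cache_q2_k) == 16, "compact 32-quant cache entry");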
