Skip to content

Commit 0090950

Browse files
authored
vulkan: In coopmat2 mmq, load q4_k/q5_k scales through shared memory (#12833)
q4_k and q5_k had many redundant global loads: the same 16 bytes of scale information were repeatedly loaded and decoded during each loop iteration. This change restructures the loops to iterate explicitly over whole blocks in the outer loop (with an unrolled inner loop) and to copy/decode the scale data into shared memory once at the start of each outer-loop iteration. The copy is pipelined, so the scale load from global memory is relatively cheap. This improves q4_k/q5_k model prompt-processing performance by around 5-7%. I briefly tried applying the same approach to q6_k and q4_0; it did not help q6_k and hurt q4_0. The big "else" path in mul_mm_cm2.comp that contained all the clamped/unclamped variants is not used as often as it originally was (e.g. due to the padded_N change), so I trimmed it down to offset some of the new complexity of the semi-manual loop unrolling.
1 parent 7ecd780 commit 0090950

File tree

3 files changed

+235
-42
lines changed

3 files changed

+235
-42
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4194,6 +4194,12 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
41944194
if (split_k == 3) {
41954195
split_k = 2;
41964196
}
4197+
if (ctx->device->coopmat2) {
4198+
// coopmat2 shader expects splits to be aligned to 256
4199+
while (split_k > 1 && ((k / split_k) % 256) != 0) {
4200+
split_k /= 2;
4201+
}
4202+
}
41974203
}
41984204
}
41994205

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
167167
block_q4_K_packed128 block;
168168
};
169169

170+
#if defined(IS_MUL_MM2)
171+
172+
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
173+
// into shared memory and then process the whole tile using those scales.
174+
// There is a fetch function that loads into private variables and then a store
175+
// function that stores into shared memory.
176+
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
177+
// the part that fetches from the structure (which has a different block layout).
178+
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
179+
const uint shAscales_stride = (BM + 2);
180+
// 1 scale per 32 elements -> 8 scales per block, per row
181+
shared vec2 shAscales[8 * shAscales_stride];
182+
uvec4 row_v;
183+
#endif
184+
185+
#if defined(DATA_A_Q4_K)
186+
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
187+
188+
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
189+
{
190+
uint tids_per_row = BLOCK_SIZE / BM;
191+
uint is_per_tid = 8 / tids_per_row;
192+
uint is_start = is_per_tid * (tid % tids_per_row);
193+
uint tid_row = tid / tids_per_row;
194+
195+
uint row = ir_BM + tid_row;
196+
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
197+
if (in_bounds || row < p.M) {
198+
row_v = data_a_q4_k_packed128[block_index].q4k[0];
199+
}
200+
}
201+
#endif
202+
#if defined(DATA_A_Q5_K)
203+
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
204+
205+
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
206+
{
207+
uint tids_per_row = BLOCK_SIZE / BM;
208+
uint is_per_tid = 8 / tids_per_row;
209+
uint is_start = is_per_tid * (tid % tids_per_row);
210+
uint tid_row = tid / tids_per_row;
211+
212+
uint row = ir_BM + tid_row;
213+
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
214+
if (in_bounds || row < p.M) {
215+
row_v = data_a_q5_k_packed128[block_index].q5k[0];
216+
}
217+
}
218+
#endif
219+
220+
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
221+
void store_scalesQ4_K(uint tid)
222+
{
223+
barrier();
224+
225+
uint tids_per_row = BLOCK_SIZE / BM;
226+
uint is_per_tid = 8 / tids_per_row;
227+
uint is_start = is_per_tid * (tid % tids_per_row);
228+
uint tid_row = tid / tids_per_row;
229+
230+
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
231+
uint is = idx + is_start;
232+
uvec4 v = row_v;
233+
const vec2 loadd = vec2(unpackFloat2x16(v.x));
234+
235+
uint32_t sc;
236+
uint32_t mbyte;
237+
238+
uint32_t scale0 = v.y;
239+
uint32_t scale4 = v.z;
240+
uint32_t scale8 = v.w;
241+
242+
uint32_t sc_lo = scale0;
243+
uint32_t mb_lo = scale4;
244+
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
245+
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
246+
247+
sc = is < 4 ? sc_lo : sc_hi;
248+
mbyte = is < 4 ? mb_lo : mb_hi;
249+
sc = sc >> (8 * (is & 3));
250+
mbyte = mbyte >> (8 * (is & 3));
251+
sc &= 0x3F;
252+
mbyte &= 0x3F;
253+
254+
const float d = loadd.x * float(sc);
255+
const float m = loadd.y * float(mbyte);
256+
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
257+
}
258+
259+
barrier();
260+
}
261+
#endif
262+
263+
#endif
264+
170265
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
171266
{
172267
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
176271
const uint b = (idx & 0x20) >> 5; // 0,1
177272
const uint is = (idx & 0xE0) >> 5; // 0..7
178273

274+
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
275+
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
276+
float d = v.x;
277+
float m = v.y;
278+
#else
179279
uvec4 v = bl128.block.q4k[0];
180-
181280
const vec2 loadd = vec2(unpackFloat2x16(v.x));
182281

183282
uint32_t sc;
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
201300

202301
const float d = loadd.x * float(sc);
203302
const float m = loadd.y * float(mbyte);
303+
#endif
204304

205305
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
206306
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
231331
const uint b = (idx & 0x20) >> 5; // 0,1
232332
const uint is = (idx & 0xE0) >> 5; // 0..7
233333

334+
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
335+
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
336+
float d = v.x;
337+
float m = v.y;
338+
#else
234339
uvec4 v = bl128.block.q5k[0];
235340

236341
const f16vec2 loadd = unpackFloat2x16(v.x);
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
256361

257362
const float16_t d = loadd.x * float16_t(sc);
258363
const float16_t m = loadd.y * float16_t(mbyte);
364+
#endif
259365

260366
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
261367
qh = ((qh >> is) & 0x101) << 4;
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
264370
qs = (qs >> (b * 4)) & 0x0F0F;
265371
qs = unpack8(qs | qh)[idx & 1];
266372

267-
float16_t ret = d * (float16_t(qs)) - m;
373+
float ret = d * float(qs) - m;
268374

269-
return ret;
375+
return float16_t(ret);
270376
}
271377

272378
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
564670
#define dequantFuncA dequantFuncQ3_K
565671
#elif defined(DATA_A_Q4_K)
566672
#define dequantFuncA dequantFuncQ4_K
673+
#define fetch_scales fetch_scalesQ4_K
674+
#define store_scales store_scalesQ4_K
567675
#elif defined(DATA_A_Q5_K)
568676
#define dequantFuncA dequantFuncQ5_K
677+
#define fetch_scales fetch_scalesQ5_K
678+
#define store_scales store_scalesQ4_K
569679
#elif defined(DATA_A_Q6_K)
570680
#define dequantFuncA dequantFuncQ6_K
571681
#elif defined(DATA_A_IQ1_S)

0 commit comments

Comments
 (0)