Skip to content

Commit 26b71e3

Browse files
committed
Revert "vulkan: matmul dequantization improvements (ggml-org#12015)"
This reverts commit fbeda90.
1 parent 6b7d234 commit 26b71e3

File tree

5 files changed

+59
-99
lines changed

5 files changed

+59
-99
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
8282
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
8383
}
8484
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
85-
const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
86-
const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
87-
return vec4(v0.x, v0.y, v1.x, v1.y);
85+
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
86+
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
87+
return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
8888
}
8989
#endif
9090

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
9292
const uint iqs = idx;
9393

9494
// Load 16b and select the byte for this element
95-
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
95+
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
9696
float16_t ret = float16_t(qs) * d;
9797
return ret;
9898
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 52 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@
3232
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3333

3434
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
35-
#if defined(A_TYPE_PACKED16)
36-
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
37-
#endif
38-
#if defined(A_TYPE_PACKED32)
39-
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
40-
#endif
41-
4235
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
4336
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
4437

@@ -250,100 +243,74 @@ void main() {
250243
#endif
251244
#elif defined(DATA_A_Q4_0)
252245
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
253-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
254-
255-
const uint ib = idx / 4;
256-
const uint iqs = idx & 0x03;
257-
258-
const float d = float(data_a_packed16[ib].d);
259-
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
260-
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
261-
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
262-
263-
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
264-
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
265-
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
266-
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
267-
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
268-
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
269-
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
270-
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
246+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
247+
248+
const uint ib = idx / 16;
249+
const uint iqs = idx & 0xF;
250+
251+
const float d = float(data_a[ib].d);
252+
const uint vui = uint(data_a[ib].qs[iqs]);
253+
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
254+
255+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
256+
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
271257
#elif defined(DATA_A_Q4_1)
272258
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
273-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
274-
275-
const uint ib = idx / 4;
276-
const uint iqs = idx & 0x03;
277-
278-
const float d = float(data_a_packed16[ib].d);
279-
const float m = float(data_a_packed16[ib].m);
280-
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
281-
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
282-
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
283-
284-
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
285-
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
286-
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
287-
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
288-
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
289-
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
290-
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
291-
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
259+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
260+
261+
const uint ib = idx / 16;
262+
const uint iqs = idx & 0xF;
263+
264+
const float d = float(data_a[ib].d);
265+
const float m = float(data_a[ib].m);
266+
const uint vui = uint(data_a[ib].qs[iqs]);
267+
const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
268+
269+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
270+
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
292271
#elif defined(DATA_A_Q5_0)
293272
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
294-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
273+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
295274

296-
const uint ib = idx / 8;
297-
const uint iqs = idx & 0x07;
275+
const uint ib = idx / 16;
276+
const uint iqs = idx & 0xF;
298277

299-
const float d = float(data_a_packed16[ib].d);
300-
const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
301-
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
302-
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
303-
304-
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
305-
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
278+
const float d = float(data_a[ib].d);
279+
const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
280+
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
281+
const uint vui = uint(data_a[ib].qs[iqs]);
282+
const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
306283

307284
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
308-
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
309285
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
310-
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
311286
#elif defined(DATA_A_Q5_1)
312287
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
313-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
314-
315-
const uint ib = idx / 8;
316-
const uint iqs = idx & 0x07;
288+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
317289

318-
const float d = float(data_a_packed16[ib].d);
319-
const float m = float(data_a_packed16[ib].m);
320-
const uint uint_qh = data_a_packed16[ib].qh;
321-
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
322-
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
290+
const uint ib = idx / 16;
291+
const uint iqs = idx & 0xF;
323292

324-
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
325-
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
293+
const float d = float(data_a[ib].d);
294+
const float m = float(data_a[ib].m);
295+
const uint uint_qh = data_a[ib].qh;
296+
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
297+
const uint vui = uint(data_a[ib].qs[iqs]);
298+
const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
326299

327300
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
328-
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
329301
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
330-
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
331302
#elif defined(DATA_A_Q8_0)
332303
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
333304
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
334305

335-
const uint ib = idx / 8;
336-
const uint iqs = idx & 0x07;
306+
const uint ib = idx / 16;
307+
const uint iqs = (idx & 0xF) * 2;
337308

338-
const float d = float(data_a_packed16[ib].d);
339-
const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
340-
const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
341-
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
309+
const float d = float(data_a[ib].d);
310+
const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
342311

343312
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
344313
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
345-
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
346-
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
347314
#elif defined(DATA_A_Q2_K)
348315
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
349316
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -656,18 +623,17 @@ void main() {
656623
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
657624
#elif defined(DATA_A_IQ4_NL)
658625
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
659-
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
626+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
660627

661-
const uint ib = idx / 8;
662-
const uint iqs = idx & 0x07;
628+
const uint ib = idx / 16;
629+
const uint iqs = idx & 0xF;
663630

664-
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
665-
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
631+
const float d = float(data_a[ib].d);
632+
const uint vui = uint(data_a[ib].qs[iqs]);
633+
const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
666634

667-
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
668-
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
669-
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
670-
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
635+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
636+
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
671637
#endif
672638
}
673639
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ struct block_q8_0
139139
struct block_q8_0_packed16
140140
{
141141
float16_t d;
142-
int16_t qs[32/2];
142+
uint16_t qs[32/2];
143143
};
144144

145145
#if defined(DATA_A_Q8_0)

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -331,17 +331,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
331331
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
332332

333333
for (const auto& tname : type_names) {
334-
std::string load_vec_quant = "2";
335-
if ((tname == "q4_0") || (tname == "q4_1"))
336-
load_vec_quant = "8";
337-
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
338-
load_vec_quant = "4";
339-
340334
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
341335
// For unaligned, load one at a time for f32/f16, or two at a time for quants
342-
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
336+
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
343337
// For aligned matmul loads
344-
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
338+
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
345339

346340
// don't generate f32 variants for coopmat2
347341
if (!coopmat2) {

0 commit comments

Comments
 (0)