Skip to content

Commit cc28742 — "q2_k better dequant"

Browse files
1 parent: 91f1d9c

File tree

2 files changed: +42 −41 lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,19 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4040

4141
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
4242
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
43-
f16vec2 d = data_a[ib0 + i].d;
43+
const f16vec2 d = data_a[ib0 + i].d;
4444
const FLOAT_TYPE dall = d.x;
4545
const FLOAT_TYPE dmin = d.y;
4646

4747
sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
4848
sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
4949
barrier();
5050

51-
uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2];
52-
uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
53-
uvec2 qs0 = uvec2(unpack8(qs0_u16));
54-
uvec2 qs16 = uvec2(unpack8(qs16_u16));
51+
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
52+
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
53+
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
54+
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
55+
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
5556

5657
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
5758
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
@@ -66,14 +67,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6667
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
6768
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
6869
[[unroll]] for (int l = 0; l < 2; ++l) {
69-
sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3),
70-
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3),
71-
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3),
72-
fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3),
73-
fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3),
74-
fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3),
75-
fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3),
76-
fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
70+
sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ],
71+
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2],
72+
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ],
73+
fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2],
74+
fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ],
75+
fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2],
76+
fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ],
77+
fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1))))))));
7778
sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8],
7879
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
7980
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -64,47 +64,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6464
const FLOAT_TYPE sc6 = scale8_f.z;
6565
const FLOAT_TYPE sc7 = scale8_f.w;
6666

67-
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
68-
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
67+
const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
68+
const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
6969

7070
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
7171
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
7272
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
7373
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
7474

75-
uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
75+
const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
7676

77-
uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
78-
uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
79-
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
80-
uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
77+
const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
78+
const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
79+
const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
80+
const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
8181

8282
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
8383
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
8484
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
8585
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
8686

87-
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
88-
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
89-
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
90-
uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
91-
92-
const uint32_t q4_0 = qs0_16_lo4.x;
93-
const uint32_t q4_1 = qs0_16_lo4.y;
94-
const uint32_t q4_2 = qs0_16_lo4.z;
95-
const uint32_t q4_3 = qs0_16_lo4.w;
96-
const uint32_t q4_4 = qs0_16_hi4.x;
97-
const uint32_t q4_5 = qs0_16_hi4.y;
98-
const uint32_t q4_6 = qs0_16_hi4.z;
99-
const uint32_t q4_7 = qs0_16_hi4.w;
100-
const uint32_t q4_8 = qs64_80_lo4.x;
101-
const uint32_t q4_9 = qs64_80_lo4.y;
102-
const uint32_t q4_10 = qs64_80_lo4.z;
103-
const uint32_t q4_11 = qs64_80_lo4.w;
104-
const uint32_t q4_12 = qs64_80_hi4.x;
105-
const uint32_t q4_13 = qs64_80_hi4.y;
106-
const uint32_t q4_14 = qs64_80_hi4.z;
107-
const uint32_t q4_15 = qs64_80_hi4.w;
87+
const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
88+
const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
89+
const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
90+
const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
91+
92+
const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
93+
const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
94+
const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
95+
const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
96+
const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
97+
const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
98+
const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
99+
const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
100+
const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
101+
const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
102+
const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
103+
const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
104+
const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
105+
const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
106+
const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
107+
const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
108108

109109
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
110110
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];

Comments (0)