Skip to content

Commit caf7ffa

Browse files
committed
q3_k optimizations
1 parent 6c52f3b commit caf7ffa

File tree

5 files changed

+24
-13
lines changed

5 files changed

+24
-13
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2121
const uint itid8 = itid%8;
2222

2323
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
24-
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
24+
const uint v_in = itid - 8*v_im; // 0...7
2525

2626
const uint l0 = 2*v_in; // 0...15
2727
const uint q_offset = 32*v_im + l0;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2121
const uint itid8 = itid%8;
2222

2323
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
24-
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
24+
const uint v_in = itid - 8*v_im; // 0...7
2525

2626
const uint8_t m = uint8_t(1 << (4 * v_im));
2727

@@ -47,6 +47,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4747
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32);
4848
barrier();
4949

50+
// 0, 1, 16, 17
51+
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
52+
qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
53+
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
54+
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
55+
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
56+
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
57+
58+
const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in]));
59+
const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8]));
60+
5061
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
5162

5263
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
@@ -60,14 +71,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6071

6172
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
6273
[[unroll]] for (int l = 0; l < 2; ++l) {
63-
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4)),
64-
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)),
65-
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
66-
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
67-
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
68-
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
69-
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
70-
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
74+
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - FLOAT_TYPE((( hmk0[l] & (m )) != 0) ? 0 : 4),
75+
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m )) != 0) ? 0 : 4),
76+
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4),
77+
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4),
78+
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 0 : 4),
79+
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4),
80+
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4),
81+
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4), sum))))))));
7182
}
7283
temp[j][n] = fma(d, sum, temp[j][n]);
7384
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
1919
const uint ix = tid/16;
2020

2121
const uint il = itid/4; // 0...3
22-
const uint ir = itid - 4*il; // 0...7 or 0...3
22+
const uint ir = itid - 4*il; // 0...3
2323
const uint n = 4;
2424

2525
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
1919
const uint ix = tid/16;
2020

2121
const uint il = itid/4; // 0...3
22-
const uint ir = itid - 4*il; // 0...7 or 0...3
22+
const uint ir = itid - 4*il; // 0...3
2323

2424
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
2525
const uint v_in = il % 2;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void compute_outputs(const uint first_row, const uint num_rows) {
8787
const uint ix = tid/16;
8888

8989
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
90-
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
90+
const uint v_in = itid - 8*v_im; // 0...7
9191

9292
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
9393
const uint is = v_in / 4;

0 commit comments

Comments
 (0)