Skip to content

Commit 91f1d9c

Browse files
committed
better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations
1 parent 6f5d62b commit 91f1d9c

File tree

1 file changed

+73
-53
lines changed

1 file changed

+73
-53
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,73 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

99
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
1010

11-
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
11+
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12+
13+
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
14+
const uint y_idx = i * QUANT_K + y_offset;
15+
16+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
17+
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
18+
19+
if (!all_threads) { // when we don't have enough blocks to use all threads
20+
if (i < num_blocks_per_row)
21+
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
22+
barrier();
23+
24+
if (i >= num_blocks_per_row)
25+
continue;
26+
}
27+
28+
const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
29+
const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
30+
31+
const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
32+
const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
33+
const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
34+
const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
35+
36+
const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
37+
const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
38+
const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
39+
const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
40+
const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
41+
42+
const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
43+
const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
44+
const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
45+
const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
46+
47+
const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
48+
const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
49+
const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
50+
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
51+
52+
if (all_threads) {
53+
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
54+
barrier();
55+
}
56+
57+
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
58+
59+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
60+
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
61+
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
62+
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
63+
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
64+
65+
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
66+
[[unroll]] for (uint l = 0; l < 4; ++l) {
67+
sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
68+
sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
69+
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
70+
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
71+
}
72+
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
73+
}
74+
}
75+
}
76+
77+
void compute_outputs(const uint first_row, const uint num_rows) {
1278
uint a_offset, b_offset, d_offset;
1379
get_offsets(a_offset, b_offset, d_offset);
1480

@@ -31,65 +97,19 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
3197
const uint s_offset = 8*v_im + is;
3298
const uint y_offset = 128*v_im + l0;
3399

34-
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
35-
36100
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
37101
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
38102
temp[j][i] = FLOAT_TYPE(0);
39103
}
40104
}
41105

42-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
43-
const uint y_idx = i * QUANT_K + y_offset;
44-
45-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
46-
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
47-
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
48-
49-
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
50-
barrier();
51-
52-
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
53-
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
54-
55-
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
56-
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
57-
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
58-
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
59-
60-
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
61-
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
62-
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
63-
uint32_t qh4_u32 = (qh_u32 & 0x30303030);
64-
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
65-
66-
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
67-
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
68-
uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
69-
uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
70-
71-
uvec4 q0 = uvec4(unpack8(q0_u32));
72-
uvec4 q1 = uvec4(unpack8(q1_u32));
73-
uvec4 q2 = uvec4(unpack8(q2_u32));
74-
uvec4 q3 = uvec4(unpack8(q3_u32));
75-
76-
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
77-
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
78-
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
79-
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
80-
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
81-
82-
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
83-
[[unroll]] for (uint l = 0; l < 4; ++l) {
84-
sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]);
85-
sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]);
86-
sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]);
87-
sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]);
88-
}
89-
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
90-
}
91-
}
106+
const uint nbr_par_th = num_blocks_per_row%it_size;
107+
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
108+
uint i0 = 0;
109+
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size) {
110+
calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
92111
}
112+
calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
93113

94114
reduce_result(temp, d_offset, first_row, num_rows, tid);
95115
}

0 commit comments

Comments
 (0)