Skip to content

Commit 6145fc7

Browse files
committed
q2_k separate out
1 parent 973bc40 commit 6145fc7

File tree

2 files changed

+76
-58
lines changed

2 files changed

+76
-58
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 74 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,74 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

88
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16];
99

10+
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
11+
12+
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint itid8, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
13+
const uint y_idx = i * QUANT_K + y_offset;
14+
15+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
16+
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
17+
18+
if (!all_threads) { // when we don't have enough blocks to use all threads
19+
if (i < num_blocks_per_row) {
20+
sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); // lower 8 bytes
21+
sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); // upper 8 bytes
22+
}
23+
barrier();
24+
25+
if (i >= num_blocks_per_row)
26+
continue;
27+
} else {
28+
sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF);
29+
sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF);
30+
barrier();
31+
}
32+
33+
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
34+
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
35+
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
36+
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
37+
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
38+
39+
const f16vec2 d = data_a[ib0 + i].d;
40+
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
41+
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
42+
43+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
44+
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
45+
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
46+
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
47+
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
48+
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
49+
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
50+
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
51+
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
52+
53+
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
54+
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
55+
[[unroll]] for (int l = 0; l < 2; ++l) {
56+
sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ],
57+
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2],
58+
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ],
59+
fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2],
60+
fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ],
61+
fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2],
62+
fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ],
63+
fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1))))))));
64+
sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8],
65+
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
66+
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],
67+
fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11],
68+
fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12],
69+
fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13],
70+
fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14],
71+
fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2))))))));
72+
}
73+
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
74+
}
75+
}
76+
}
77+
1078
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
1179
uint a_offset, b_offset, d_offset;
1280
get_offsets(a_offset, b_offset, d_offset);
@@ -27,68 +95,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2795
const uint q_offset = 32*v_im + l0;
2896
const uint y_offset = 128*v_im + l0;
2997

30-
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
31-
3298
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
3399
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
34100
temp[j][i] = FLOAT_TYPE(0);
35101
}
36102
}
37103

38-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
39-
const uint y_idx = i * QUANT_K + y_offset;
40-
41-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
42-
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
43-
44-
sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
45-
sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
46-
barrier();
47-
48-
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
49-
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
50-
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
51-
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
52-
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
53-
54-
const f16vec2 d = data_a[ib0 + i].d;
55-
const FLOAT_TYPE dall = d.x;
56-
const FLOAT_TYPE dmin = d.y;
57-
58-
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
59-
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
60-
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
61-
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
62-
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
63-
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
64-
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
65-
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
66-
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
67-
68-
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
69-
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
70-
[[unroll]] for (int l = 0; l < 2; ++l) {
71-
sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ],
72-
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2],
73-
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ],
74-
fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2],
75-
fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ],
76-
fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2],
77-
fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ],
78-
fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1))))))));
79-
sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8],
80-
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
81-
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],
82-
fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11],
83-
fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12],
84-
fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13],
85-
fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14],
86-
fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2))))))));
87-
}
88-
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
89-
}
90-
}
91-
}
104+
const uint nbr_par_th = num_blocks_per_row%it_size;
105+
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
106+
uint i0 = 0;
107+
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
108+
calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
109+
calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
92110

93111
reduce_result(temp, d_offset, first_row, num_rows, tid);
94112
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
1616

1717
if (!all_threads) { // when we don't have enough blocks to use all threads
1818
if (i < num_blocks_per_row)
19-
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32);
19+
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
2020
barrier();
2121

2222
if (i >= num_blocks_per_row)
@@ -38,7 +38,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
3838
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
3939

4040
if (all_threads) {
41-
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32);
41+
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
4242
barrier();
4343
}
4444

0 commit comments

Comments
 (0)