Skip to content

Commit 6f5d62b

Browse files
committed
q5_k
1 parent cdf70cf commit 6f5d62b

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
5252
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
5353

5454
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
55-
const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2;
56-
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f));
57-
const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h));
55+
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
56+
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
57+
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
5858

5959
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
6060
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4646
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
4747
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
4848

49-
uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ];
50-
uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2];
51-
uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4];
52-
uvec2 scale0 = uvec2(unpack8(scale0_u16));
53-
uvec2 scale4 = uvec2(unpack8(scale4_u16));
54-
uvec2 scale8 = uvec2(unpack8(scale8_u16));
55-
56-
const uint32_t sc0 = ( scale0.x & 0x3f);
57-
const uint32_t sc1 = ( scale0.y & 0x3f);
58-
const uint32_t sc2 = ( scale4.x & 0x3f);
59-
const uint32_t sc3 = ( scale4.y & 0x3f);
60-
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
61-
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
62-
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
63-
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
49+
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
50+
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
51+
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
52+
53+
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
54+
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
55+
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
56+
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
57+
58+
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
59+
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
60+
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
61+
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
62+
const FLOAT_TYPE sc4 = scale8_f.x;
63+
const FLOAT_TYPE sc5 = scale8_f.y;
64+
const FLOAT_TYPE sc6 = scale8_f.z;
65+
const FLOAT_TYPE sc7 = scale8_f.w;
6466

6567
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
6668
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);

0 commit comments

Comments
 (0)