Skip to content

Commit ced706d

Browse files
committed
better q4_k scales
1 parent dc6afb7 commit ced706d

File tree

1 file changed

+48
-46
lines changed

1 file changed

+48
-46
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 48 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -43,55 +43,57 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
43 43

44 44
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
45 45
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
46-
f16vec2 d = data_a[ib0 + i].d;
46+
const f16vec2 d = data_a[ib0 + i].d;
47 47
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
48 48
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
49 49

50-
uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ];
51-
uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2];
52-
uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4];
53-
uvec2 scale0 = uvec2(unpack8(scale0_u16));
54-
uvec2 scale4 = uvec2(unpack8(scale4_u16));
55-
uvec2 scale8 = uvec2(unpack8(scale8_u16));
56-
57-
const uint32_t sc0 = ( scale0.x & 0x3f);
58-
const uint32_t sc1 = ( scale0.y & 0x3f);
59-
const uint32_t sc2 = ( scale4.x & 0x3f);
60-
const uint32_t sc3 = ( scale4.y & 0x3f);
61-
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
62-
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
63-
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
64-
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
65-
66-
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
67-
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
68-
69-
uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
70-
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
71-
uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
72-
uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
73-
74-
uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
75-
uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
76-
uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
77-
uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
78-
79-
const uint32_t q4_0 = qs0_lo4.x;
80-
const uint32_t q4_1 = qs0_lo4.y;
81-
const uint32_t q4_2 = qs0_lo4.z;
82-
const uint32_t q4_3 = qs0_lo4.w;
83-
const uint32_t q4_4 = qs0_hi4.x;
84-
const uint32_t q4_5 = qs0_hi4.y;
85-
const uint32_t q4_6 = qs0_hi4.z;
86-
const uint32_t q4_7 = qs0_hi4.w;
87-
const uint32_t q4_8 = qs64_lo4.x;
88-
const uint32_t q4_9 = qs64_lo4.y;
89-
const uint32_t q4_10 = qs64_lo4.z;
90-
const uint32_t q4_11 = qs64_lo4.w;
91-
const uint32_t q4_12 = qs64_hi4.x;
92-
const uint32_t q4_13 = qs64_hi4.y;
93-
const uint32_t q4_14 = qs64_hi4.z;
94-
const uint32_t q4_15 = qs64_hi4.w;
50+
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
51+
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
52+
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
53+
54+
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
55+
const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2;
56+
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f));
57+
const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h));
58+
59+
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
60+
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
61+
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
62+
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
63+
const FLOAT_TYPE sc4 = scale8_f.x;
64+
const FLOAT_TYPE sc5 = scale8_f.y;
65+
const FLOAT_TYPE sc6 = scale8_f.z;
66+
const FLOAT_TYPE sc7 = scale8_f.w;
67+
68+
const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
69+
const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
70+
71+
const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
72+
const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
73+
const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
74+
const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
75+
76+
const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
77+
const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
78+
const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
79+
const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
80+
81+
const FLOAT_TYPE q4_0 = qs0_lo4.x;
82+
const FLOAT_TYPE q4_1 = qs0_lo4.y;
83+
const FLOAT_TYPE q4_2 = qs0_lo4.z;
84+
const FLOAT_TYPE q4_3 = qs0_lo4.w;
85+
const FLOAT_TYPE q4_4 = qs0_hi4.x;
86+
const FLOAT_TYPE q4_5 = qs0_hi4.y;
87+
const FLOAT_TYPE q4_6 = qs0_hi4.z;
88+
const FLOAT_TYPE q4_7 = qs0_hi4.w;
89+
const FLOAT_TYPE q4_8 = qs64_lo4.x;
90+
const FLOAT_TYPE q4_9 = qs64_lo4.y;
91+
const FLOAT_TYPE q4_10 = qs64_lo4.z;
92+
const FLOAT_TYPE q4_11 = qs64_lo4.w;
93+
const FLOAT_TYPE q4_12 = qs64_hi4.x;
94+
const FLOAT_TYPE q4_13 = qs64_hi4.y;
95+
const FLOAT_TYPE q4_14 = qs64_hi4.z;
96+
const FLOAT_TYPE q4_15 = qs64_hi4.w;
95 97

96 98
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
97 99
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];

0 commit comments

Comments
 (0)