@@ -46,21 +46,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4646 const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
4747 const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
4848
49- uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ];
50- uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2];
51- uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4];
52- uvec2 scale0 = uvec2(unpack8(scale0_u16));
53- uvec2 scale4 = uvec2(unpack8(scale4_u16));
54- uvec2 scale8 = uvec2(unpack8(scale8_u16));
55-
56- const uint32_t sc0 = ( scale0.x & 0x3f);
57- const uint32_t sc1 = ( scale0.y & 0x3f);
58- const uint32_t sc2 = ( scale4.x & 0x3f);
59- const uint32_t sc3 = ( scale4.y & 0x3f);
60- const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
61- const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
62- const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
63- const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
49+ const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
50+ const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
51+ const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
52+
53+ const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
54+ const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
55+ const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
56+ const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
57+
58+ const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
59+ const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
60+ const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
61+ const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
62+ const FLOAT_TYPE sc4 = scale8_f.x;
63+ const FLOAT_TYPE sc5 = scale8_f.y;
64+ const FLOAT_TYPE sc6 = scale8_f.z;
65+ const FLOAT_TYPE sc7 = scale8_f.w;
6466
6567 uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
6668 uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
0 commit comments