@@ -43,55 +43,57 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4343
4444 [[unroll]] for (uint n = 0; n < num_rows; ++n) {
4545 const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
46- f16vec2 d = data_a[ib0 + i].d;
46+ const f16vec2 d = data_a[ib0 + i].d;
4747 const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
4848 const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
4949
50- uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ];
51- uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2];
52- uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4];
53- uvec2 scale0 = uvec2(unpack8(scale0_u16));
54- uvec2 scale4 = uvec2(unpack8(scale4_u16));
55- uvec2 scale8 = uvec2(unpack8(scale8_u16));
56-
57- const uint32_t sc0 = ( scale0.x & 0x3f);
58- const uint32_t sc1 = ( scale0.y & 0x3f);
59- const uint32_t sc2 = ( scale4.x & 0x3f);
60- const uint32_t sc3 = ( scale4.y & 0x3f);
61- const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
62- const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
63- const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
64- const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
65-
66- uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
67- uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
68-
69- uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
70- uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
71- uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
72- uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
73-
74- uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
75- uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
76- uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
77- uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
78-
79- const uint32_t q4_0 = qs0_lo4.x;
80- const uint32_t q4_1 = qs0_lo4.y;
81- const uint32_t q4_2 = qs0_lo4.z;
82- const uint32_t q4_3 = qs0_lo4.w;
83- const uint32_t q4_4 = qs0_hi4.x;
84- const uint32_t q4_5 = qs0_hi4.y;
85- const uint32_t q4_6 = qs0_hi4.z;
86- const uint32_t q4_7 = qs0_hi4.w;
87- const uint32_t q4_8 = qs64_lo4.x;
88- const uint32_t q4_9 = qs64_lo4.y;
89- const uint32_t q4_10 = qs64_lo4.z;
90- const uint32_t q4_11 = qs64_lo4.w;
91- const uint32_t q4_12 = qs64_hi4.x;
92- const uint32_t q4_13 = qs64_hi4.y;
93- const uint32_t q4_14 = qs64_hi4.z;
94- const uint32_t q4_15 = qs64_hi4.w;
50+ const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
51+ const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
52+ const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
53+
54+ const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
55+ const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2;
56+ const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f));
57+ const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h));
58+
59+ const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
60+ const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
61+ const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
62+ const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
63+ const FLOAT_TYPE sc4 = scale8_f.x;
64+ const FLOAT_TYPE sc5 = scale8_f.y;
65+ const FLOAT_TYPE sc6 = scale8_f.z;
66+ const FLOAT_TYPE sc7 = scale8_f.w;
67+
68+ const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
69+ const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
70+
71+ const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
72+ const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
73+ const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
74+ const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
75+
76+ const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
77+ const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
78+ const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
79+ const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
80+
81+ const FLOAT_TYPE q4_0 = qs0_lo4.x;
82+ const FLOAT_TYPE q4_1 = qs0_lo4.y;
83+ const FLOAT_TYPE q4_2 = qs0_lo4.z;
84+ const FLOAT_TYPE q4_3 = qs0_lo4.w;
85+ const FLOAT_TYPE q4_4 = qs0_hi4.x;
86+ const FLOAT_TYPE q4_5 = qs0_hi4.y;
87+ const FLOAT_TYPE q4_6 = qs0_hi4.z;
88+ const FLOAT_TYPE q4_7 = qs0_hi4.w;
89+ const FLOAT_TYPE q4_8 = qs64_lo4.x;
90+ const FLOAT_TYPE q4_9 = qs64_lo4.y;
91+ const FLOAT_TYPE q4_10 = qs64_lo4.z;
92+ const FLOAT_TYPE q4_11 = qs64_lo4.w;
93+ const FLOAT_TYPE q4_12 = qs64_hi4.x;
94+ const FLOAT_TYPE q4_13 = qs64_hi4.y;
95+ const FLOAT_TYPE q4_14 = qs64_hi4.z;
96+ const FLOAT_TYPE q4_15 = qs64_hi4.w;
9597
9698 [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
9799 B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
0 commit comments