@@ -34,9 +34,6 @@ void main() {
3434 const uint q_offset = 32*v_im + l0;
3535 const uint y_offset = 64*v_im + l0;
3636
37- const uint8_t hm1 = uint8_t(1 << (2*v_im));
38- const uint8_t hm2 = uint8_t(hm1 << 4);
39-
4037 FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
4138
4239 [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
@@ -71,6 +68,18 @@ void main() {
7168 uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
7269 uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
7370
71+ uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
72+
73+ uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
74+ uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
75+ uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
76+ uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
77+
78+ qs0_16_u32_lo4 += qs0_16_lo4_offset16;
79+ qs0_16_u32_hi4 += qs0_16_hi4_offset16;
80+ qs64_80_u32_lo4 += qs64_80_lo4_offset16;
81+ qs64_80_u32_hi4 += qs64_80_hi4_offset16;
82+
7483 uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
7584 uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
7685 uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
@@ -102,31 +111,26 @@ void main() {
102111 B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
103112 B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
104113
105- uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
106- uint32_t qh1 = qh0 >> 8;
107- uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
108- uint32_t qh17 = qh16 >> 8;
109-
110114 const FLOAT_TYPE sx =
111- fma(FLOAT_TYPE(by10.x), ( q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)) ,
112- fma(FLOAT_TYPE(by10.y), ( q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)) ,
113- fma(FLOAT_TYPE(by116.x), ( q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)) ,
114- FLOAT_TYPE(by116.y) * ( q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)) )));
115+ fma(FLOAT_TYPE(by10.x), q4_0,
116+ fma(FLOAT_TYPE(by10.y), q4_1,
117+ fma(FLOAT_TYPE(by116.x), q4_2,
118+ FLOAT_TYPE(by116.y) * q4_3)));
115119 const FLOAT_TYPE sy =
116- fma(FLOAT_TYPE(by132.x), ( q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)) ,
117- fma(FLOAT_TYPE(by132.y), ( q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)) ,
118- fma(FLOAT_TYPE(by148.x), ( q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)) ,
119- FLOAT_TYPE(by148.y) * ( q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)) )));
120+ fma(FLOAT_TYPE(by132.x), q4_4,
121+ fma(FLOAT_TYPE(by132.y), q4_5,
122+ fma(FLOAT_TYPE(by148.x), q4_6,
123+ FLOAT_TYPE(by148.y) * q4_7)));
120124 const FLOAT_TYPE sz =
121- fma(FLOAT_TYPE(by20.x), ( q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)) ,
122- fma(FLOAT_TYPE(by20.y), ( q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)) ,
123- fma(FLOAT_TYPE(by216.x), ( q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)) ,
124- FLOAT_TYPE(by216.y) * ( q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)) )));
125+ fma(FLOAT_TYPE(by20.x), q4_8,
126+ fma(FLOAT_TYPE(by20.y), q4_9,
127+ fma(FLOAT_TYPE(by216.x), q4_10,
128+ FLOAT_TYPE(by216.y) * q4_11)));
125129 const FLOAT_TYPE sw =
126- fma(FLOAT_TYPE(by232.x), ( q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)) ,
127- fma(FLOAT_TYPE(by232.y), ( q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)) ,
128- fma(FLOAT_TYPE(by248.x), ( q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)) ,
129- FLOAT_TYPE(by248.y) * ( q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)) )));
130+ fma(FLOAT_TYPE(by232.x), q4_12,
131+ fma(FLOAT_TYPE(by232.y), q4_13,
132+ fma(FLOAT_TYPE(by248.x), q4_14,
133+ FLOAT_TYPE(by248.y) * q4_15)));
130134 const FLOAT_TYPE smin =
131135 fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
132136 fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
0 commit comments