@@ -21,9 +21,13 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2121    const uint itid8 = itid%8;
2222
2323    const uint v_im = itid/8;                               // 0 or 1. 0 computes 0..., 1 computes 128...
24+     const uint v_im4 = v_im*4;
2425    const uint v_in = itid - 8*v_im;                        // 0...7
2526
26-     const uint8_t m = uint8_t(1 << (4 * v_im));
27+     const uint32_t m = 0x01010101 << (4 * v_im);
28+     uint32_t hm_m[4];
29+     [[unroll]] for (uint j = 0; j < 4; ++j)
30+         hm_m[j] = m << j;
2731
2832    const uint l0 = 2*v_in;                                 // 0...15
2933    const uint q_offset = 32*v_im + l0;
@@ -44,7 +48,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4448            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
4549            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
4650
47-             sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im ) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im  + 2*(itid8/4)) & 0x3) << 4)) - 32);
51+             sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4 ) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4  + 2*(itid8/4)) & 0x3) << 4)) - 32);
4852            barrier();
4953
5054            // 0, 1, 16, 17
@@ -55,8 +59,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
5559            const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
5660            const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
5761
58-             const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in]));
59-             const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8]));
62+             const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
63+             const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> (    v_im4)) << 2));
64+             const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
65+             const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
66+             const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
6067
6168            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
6269
@@ -71,14 +78,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
7178
7279                FLOAT_TYPE sum = FLOAT_TYPE(0.0);
7380                [[unroll]] for (int l = 0; l < 2; ++l) {
74-                     sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - FLOAT_TYPE((( hmk0[l] & (m     )) != 0) ? 0 : 4) ,
75-                           fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m     )) != 0) ? 0 : 4) ,
76-                           fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4) ,
77-                           fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4) ,
78-                           fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 0 : 4) ,
79-                           fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4) ,
80-                           fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4) ,
81-                           fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4) , sum))))))));
81+                     sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ] ,
82+                           fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2] ,
83+                           fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ] ,
84+                           fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2] ,
85+                           fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ] ,
86+                           fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2] ,
87+                           fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ] ,
88+                           fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2] , sum))))))));
8289                }
8390                temp[j][n] = fma(d, sum, temp[j][n]);
8491            }
0 commit comments