@@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
2121 const uint itid8 = itid%8;
2222
2323 const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
24- const uint v_in = itid - 8*v_im; // 0...15 or 0... 7
24+ const uint v_in = itid - 8*v_im; // 0...7
2525
2626 const uint8_t m = uint8_t(1 << (4 * v_im));
2727
@@ -47,6 +47,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4747 sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32);
4848 barrier();
4949
50+ // 0, 1, 16, 17
51+ uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
52+ qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
53+ const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
54+ const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
55+ const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
56+ const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
57+
58+ const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in]));
59+ const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8]));
60+
5061 [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
5162
5263 B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
@@ -60,14 +71,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6071
6172 FLOAT_TYPE sum = FLOAT_TYPE(0.0);
6273 [[unroll]] for (int l = 0; l < 2; ++l) {
63- sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4) ),
64- fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16 ] & (m )) != 0) ? 0 : 4) ),
65- fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4) ),
66- fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16 ] & (m << 1)) != 0) ? 0 : 4) ),
67- fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4) ),
68- fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16 ] & (m << 2)) != 0) ? 0 : 4) ),
69- fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4) ),
70- fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16 ] & (m << 3)) != 0) ? 0 : 4) ), sum))))))));
74+ sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[ l ] - FLOAT_TYPE ((( hmk0[l ] & (m )) != 0) ? 0 : 4),
75+ fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE (((hmk16[l ] & (m )) != 0) ? 0 : 4),
76+ fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[ l ] - FLOAT_TYPE ((( hmk0[l ] & (m << 1)) != 0) ? 0 : 4),
77+ fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE (((hmk16[l ] & (m << 1)) != 0) ? 0 : 4),
78+ fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[ l ] - FLOAT_TYPE ((( hmk0[l ] & (m << 2)) != 0) ? 0 : 4),
79+ fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE (((hmk16[l ] & (m << 2)) != 0) ? 0 : 4),
80+ fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[ l ] - FLOAT_TYPE ((( hmk0[l ] & (m << 3)) != 0) ? 0 : 4),
81+ fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE (((hmk16[l ] & (m << 3)) != 0) ? 0 : 4), sum))))))));
7182 }
7283 temp[j][n] = fma(d, sum, temp[j][n]);
7384 }
0 commit comments