@@ -8,7 +8,73 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88
99shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
1010
11- void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
11+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12+
13+ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
14+ const uint y_idx = i * QUANT_K + y_offset;
15+
16+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
17+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
18+
19+ if (!all_threads) { // when we don't have enough blocks to use all threads
20+ if (i < num_blocks_per_row)
21+ sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
22+ barrier();
23+
24+ if (i >= num_blocks_per_row)
25+ continue;
26+ }
27+
28+ const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
29+ const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
30+
31+ const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
32+ const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
33+ const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
34+ const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
35+
36+ const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
37+ const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
38+ const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
39+ const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
40+ const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
41+
42+ const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
43+ const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
44+ const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
45+ const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
46+
47+ const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
48+ const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
49+ const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
50+ const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
51+
52+ if (all_threads) {
53+ sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
54+ barrier();
55+ }
56+
57+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
58+
59+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
60+ B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
61+ B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
62+ B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
63+ B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
64+
65+ FLOAT_TYPE sum[4] = {0, 0, 0, 0};
66+ [[unroll]] for (uint l = 0; l < 4; ++l) {
67+ sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
68+ sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
69+ sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
70+ sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
71+ }
72+ temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
73+ }
74+ }
75+ }
76+
77+ void compute_outputs(const uint first_row, const uint num_rows) {
1278 uint a_offset, b_offset, d_offset;
1379 get_offsets(a_offset, b_offset, d_offset);
1480
@@ -31,65 +97,19 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
3197 const uint s_offset = 8*v_im + is;
3298 const uint y_offset = 128*v_im + l0;
3399
34- FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
35-
36100 [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
37101 [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
38102 temp[j][i] = FLOAT_TYPE(0);
39103 }
40104 }
41105
42- [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
43- const uint y_idx = i * QUANT_K + y_offset;
44-
45- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
46- const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
47- const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
48-
49- sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
50- barrier();
51-
52- uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
53- uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
54-
55- uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
56- uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
57- uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
58- uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
59-
60- uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
61- uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
62- uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
63- uint32_t qh4_u32 = (qh_u32 & 0x30303030);
64- uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
65-
66- uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
67- uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
68- uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
69- uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
70-
71- uvec4 q0 = uvec4(unpack8(q0_u32));
72- uvec4 q1 = uvec4(unpack8(q1_u32));
73- uvec4 q2 = uvec4(unpack8(q2_u32));
74- uvec4 q3 = uvec4(unpack8(q3_u32));
75-
76- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
77- B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
78- B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
79- B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
80- B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
81-
82- FLOAT_TYPE sum[4] = {0, 0, 0, 0};
83- [[unroll]] for (uint l = 0; l < 4; ++l) {
84- sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]);
85- sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]);
86- sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]);
87- sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]);
88- }
89- temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
90- }
91- }
106+ const uint nbr_par_th = num_blocks_per_row%it_size;
107+ const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
108+ uint i0 = 0;
109+ [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) {
110+ calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
92111 }
112+ calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
93113
94114 reduce_result(temp, d_offset, first_row, num_rows, tid);
95115}
0 commit comments