@@ -7,6 +7,74 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16];
 
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint itid8, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
+    const uint y_idx = i * QUANT_K + y_offset;
+
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+
+        if (!all_threads) { // when we don't have enough blocks to use all threads
+            if (i < num_blocks_per_row) {
+                sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); // lower 8 bytes
+                sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); // upper 8 bytes
+            }
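+            // all invocations reach the barrier, even those past the end of the row, so control flow stays uniform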
+            barrier();
+
+            if (i >= num_blocks_per_row)
+                continue;
+        } else {
+            sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF);
+            sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF);
+            barrier();
+        }
+
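+        // merge two 16-bit quant loads (16 bytes apart) into one word; each byte packs four 2-bit values, extracted shift by shift below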
+        const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
+
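+        // d.x (dall) scales the quantized products, d.y (dmin) scales the per-sub-block minimums subtracted below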
+        const f16vec2 d = data_a[ib0 + i].d;
+        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
+            B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
+            B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
+            B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
+            B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
+            B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
+            B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
+            B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+
+            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ],
+                       fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2],
+                       fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ],
+                       fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2],
+                       fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ],
+                       fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2],
+                       fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ],
+                       fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8],
+                       fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
+                       fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],
+                       fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11],
+                       fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12],
+                       fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13],
+                       fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14],
+                       fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2))))))));
+            }
+            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+        }
+    }
+}
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -27,68 +95,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
-
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
             temp[j][i] = FLOAT_TYPE(0);
         }
     }
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
-        const uint y_idx = i * QUANT_K + y_offset;
-
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-
-            sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
-            sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
-            barrier();
-
-            const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
-            const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
-            const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
-            const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
-            const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
-
-            const f16vec2 d = data_a[ib0 + i].d;
-            const FLOAT_TYPE dall = d.x;
-            const FLOAT_TYPE dmin = d.y;
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
-                B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
-                B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
-                B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
-                B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
-                B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
-                B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
-                B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
-
-                FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-                FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-                [[unroll]] for (int l = 0; l < 2; ++l) {
-                    sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ],
-                           fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2],
-                           fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ],
-                           fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2],
-                           fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ],
-                           fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2],
-                           fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ],
-                           fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1))))))));
-                    sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8],
-                           fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
-                           fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],
-                           fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11],
-                           fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12],
-                           fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13],
-                           fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14],
-                           fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2))))))));
-                }
-                temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
-            }
-        }
-    }
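+    // do the passes where every thread has a block to work on (no bounds checks needed),
+    // then one final partial pass for the num_blocks_per_row % it_size leftover blocks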
+    const uint nbr_par_th = num_blocks_per_row%it_size;
+    const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
+    uint i0 = 0;
+    [[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
+        calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
+    calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
 
     reduce_result(temp, d_offset, first_row, num_rows, tid);
 }