@@ -7,34 +7,50 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77
88FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
99
10- void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
11- const uint y_idx = i * QUANT_K + 32 * ib32;
12-
13- uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
14- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
15- const float d = float(data_a[ibi].d);
16- const uint qh = data_a[ibi].qh[ib32];
17- const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
18- const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
19-
20- [[unroll]] for (uint l = 0; l < 4; ++l) {
21- const uint qs = data_a[ibi].qs[4 * ib32 + l];
22- const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
23- const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
24-
25- [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
26- vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
27- vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
28-
29- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
30- [[unroll]] for (int k = 0; k < 4; ++k) {
31- sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
32- fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
33- }
34- temp[j][n] = fma(dl, sum, temp[j][n]);
10+ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i,
11+ const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
12+ const uint y_idx_base = i * QUANT_K + 32 * ib32;
13+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
14+ const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx_base) / 4;
15+ [[unroll]] for (uint l = 0; l < 4; ++l) {
16+ const vec4 b_val_0 = vec4(data_b_v4[base_b_idx + 2 * l]);
17+ const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);
18+
19+ // index for data_a
20+ uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
21+
22+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
23+ const float d = float(data_a[ibi].d);
24+ const uint qh = data_a[ibi].qh[ib32];
25+
26+ const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
27+ const uint qs = data_a[ibi].qs[4 * ib32 + l];
28+ const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
29+ const uint16_t grid = uint16_t(iq1s_grid[qs | (idxhi << 8)]);
30+
31+ const float delta_val = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
32+ const vec4 delta_v = vec4(delta_val);
33+ const vec4 fbits0 = vec4(
34+ float(bitfieldExtract(grid, 0, 2)),
35+ float(bitfieldExtract(grid, 2, 2)),
36+ float(bitfieldExtract(grid, 4, 2)),
37+ float(bitfieldExtract(grid, 6, 2))
38+ );
39+ const vec4 fbits1 = vec4(
40+ float(bitfieldExtract(grid, 8, 2)),
41+ float(bitfieldExtract(grid, 10, 2)),
42+ float(bitfieldExtract(grid, 12, 2)),
43+ float(bitfieldExtract(grid, 14, 2))
44+ );
45+
46+ vec4 sum_v = fma(b_val_0, fbits0 + delta_v, vec4(0.0));
47+ sum_v = fma(b_val_1, fbits1 + delta_v, sum_v);
48+ FLOAT_TYPE sum = dot(sum_v, vec4(1.0));
49+
50+ temp[j][n] = fma(dl, sum, temp[j][n]);
51+ ibi += num_blocks_per_row;
3552 }
3653 }
37- ibi += num_blocks_per_row;
3854 }
3955}
4056
0 commit comments