55
// NOTE(review): this text appears to be a rendered diff, not a plain source file.
// The digits fused to the start of each line look like old/new line numbers;
// "-" marks a removed line and "+" an added one.

// Workgroup size: x is taken from specialization constant id 0; y and z are fixed at 1.
66layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77
// Previous declarations of the scale caches (removed side of the diff).
8- shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
9- shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
// Replacement declarations: the same shared caches with an extra leading dimension
// of 2, indexed by csel. calc_superblock stores the low 4 bits of each scale byte
// in sccache1 and the next 4 bits (scale >> 4) in sccache2.
8+ shared FLOAT_TYPE sccache1[2][ BLOCK_SIZE/16][16];
9+ shared FLOAT_TYPE sccache2[2][ BLOCK_SIZE/16][16];
1010
// Accumulators indexed [column][row]; calc_superblock folds each partial result
// into temp[j][n] with fma.
1111FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
// Selects which of the two sccache halves is in use. calc_superblock flips it
// (csel ^= 1) at the start of every row iteration.
// NOTE(review): presumably alternating halves is what allows the barrier() that
// used to precede the cache writes to be dropped — confirm against the full shader.
12+ uint csel = 0;
1213
1314void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
1415 const uint y_idx = i * QUANT_K + y_offset;
1516
1617 [[unroll]] for (uint n = 0; n < num_rows; ++n) {
1718 const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
19+ csel ^= 1;
1820
19- barrier();
2021 if (!all_threads) { // when we don't have enough blocks to use all threads
2122 if (i < num_blocks_per_row) {
2223 const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
23- sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
24- sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
24+ sccache1[csel][ ix][itid] = FLOAT_TYPE(scale & 0xF);
25+ sccache2[csel][ ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
2526 }
2627 barrier();
2728
2829 if (i >= num_blocks_per_row)
2930 continue;
3031 } else {
3132 const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
32- sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
33- sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
33+ sccache1[csel][ ix][itid] = FLOAT_TYPE(scale & 0xF);
34+ sccache2[csel][ ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
3435 barrier();
3536 }
3637
@@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
5758 FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
5859 FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
5960 [[unroll]] for (int l = 0; l < 2; ++l) {
60- sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
61- fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
62- fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
63- fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
64- fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
65- fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
66- fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
67- fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
68- sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
69- fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
70- fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
71- fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
72- fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
73- fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
74- fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
75- fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
61+ sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ ix][ 8*v_im] * qs_u32_0[l ],
62+ fma(FLOAT_TYPE(b16[l]), sccache1[csel][ ix][1 + 8*v_im] * qs_u32_0[l+2],
63+ fma(FLOAT_TYPE(b32[l]), sccache1[csel][ ix][2 + 8*v_im] * qs_u32_2[l ],
64+ fma(FLOAT_TYPE(b48[l]), sccache1[csel][ ix][3 + 8*v_im] * qs_u32_2[l+2],
65+ fma(FLOAT_TYPE(b64[l]), sccache1[csel][ ix][4 + 8*v_im] * qs_u32_4[l ],
66+ fma(FLOAT_TYPE(b80[l]), sccache1[csel][ ix][5 + 8*v_im] * qs_u32_4[l+2],
67+ fma(FLOAT_TYPE(b96[l]), sccache1[csel][ ix][6 + 8*v_im] * qs_u32_6[l ],
68+ fma(FLOAT_TYPE(b112[l]), sccache1[csel][ ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
69+ sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ ix][ 8*v_im],
70+ fma(FLOAT_TYPE(b16[l]), sccache2[csel][ ix][1 + 8*v_im],
71+ fma(FLOAT_TYPE(b32[l]), sccache2[csel][ ix][2 + 8*v_im],
72+ fma(FLOAT_TYPE(b48[l]), sccache2[csel][ ix][3 + 8*v_im],
73+ fma(FLOAT_TYPE(b64[l]), sccache2[csel][ ix][4 + 8*v_im],
74+ fma(FLOAT_TYPE(b80[l]), sccache2[csel][ ix][5 + 8*v_im],
75+ fma(FLOAT_TYPE(b96[l]), sccache2[csel][ ix][6 + 8*v_im],
76+ fma(FLOAT_TYPE(b112[l]), sccache2[csel][ ix][7 + 8*v_im], sum2))))))));
7677 }
7778 temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
7879 }
0 commit comments