55
66layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77
8- shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
9- shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
8+ shared FLOAT_TYPE sccache1[2][ BLOCK_SIZE/16][16];
9+ shared FLOAT_TYPE sccache2[2][ BLOCK_SIZE/16][16];
1010
1111FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
12+ uint csel = 0;
1213
1314void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
1415    const uint y_idx = i * QUANT_K + y_offset;
1516
1617    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
1718        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
19+         csel ^= 1;
1820
19-         barrier();
2021        if (!all_threads) { // when we don't have enough blocks to use all threads
2122            if (i < num_blocks_per_row) {
2223                const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
23-                 sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
24-                 sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
24+                 sccache1[csel][ ix][itid] = FLOAT_TYPE(scale & 0xF);
25+                 sccache2[csel][ ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
2526            }
2627            barrier();
2728
2829            if (i >= num_blocks_per_row)
2930                continue;
3031        } else {
3132            const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
32-             sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
33-             sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
33+             sccache1[csel][ ix][itid] = FLOAT_TYPE(scale & 0xF);
34+             sccache2[csel][ ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
3435            barrier();
3536        }
3637
@@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
5758            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
5859            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
5960            [[unroll]] for (int l = 0; l < 2; ++l) {
60-                 sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[ix][    8*v_im] * qs_u32_0[l  ],
61-                        fma(FLOAT_TYPE(b16[l]),  sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
62-                        fma(FLOAT_TYPE(b32[l]),  sccache1[ix][2 + 8*v_im] * qs_u32_2[l  ],
63-                        fma(FLOAT_TYPE(b48[l]),  sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
64-                        fma(FLOAT_TYPE(b64[l]),  sccache1[ix][4 + 8*v_im] * qs_u32_4[l  ],
65-                        fma(FLOAT_TYPE(b80[l]),  sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
66-                        fma(FLOAT_TYPE(b96[l]),  sccache1[ix][6 + 8*v_im] * qs_u32_6[l  ],
67-                        fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
68-                 sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[ix][    8*v_im],
69-                        fma(FLOAT_TYPE(b16[l]),  sccache2[ix][1 + 8*v_im],
70-                        fma(FLOAT_TYPE(b32[l]),  sccache2[ix][2 + 8*v_im],
71-                        fma(FLOAT_TYPE(b48[l]),  sccache2[ix][3 + 8*v_im],
72-                        fma(FLOAT_TYPE(b64[l]),  sccache2[ix][4 + 8*v_im],
73-                        fma(FLOAT_TYPE(b80[l]),  sccache2[ix][5 + 8*v_im],
74-                        fma(FLOAT_TYPE(b96[l]),  sccache2[ix][6 + 8*v_im],
75-                        fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
61+                 sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[csel][ ix][    8*v_im] * qs_u32_0[l  ],
62+                        fma(FLOAT_TYPE(b16[l]),  sccache1[csel][ ix][1 + 8*v_im] * qs_u32_0[l+2],
63+                        fma(FLOAT_TYPE(b32[l]),  sccache1[csel][ ix][2 + 8*v_im] * qs_u32_2[l  ],
64+                        fma(FLOAT_TYPE(b48[l]),  sccache1[csel][ ix][3 + 8*v_im] * qs_u32_2[l+2],
65+                        fma(FLOAT_TYPE(b64[l]),  sccache1[csel][ ix][4 + 8*v_im] * qs_u32_4[l  ],
66+                        fma(FLOAT_TYPE(b80[l]),  sccache1[csel][ ix][5 + 8*v_im] * qs_u32_4[l+2],
67+                        fma(FLOAT_TYPE(b96[l]),  sccache1[csel][ ix][6 + 8*v_im] * qs_u32_6[l  ],
68+                        fma(FLOAT_TYPE(b112[l]), sccache1[csel][ ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
69+                 sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[csel][ ix][    8*v_im],
70+                        fma(FLOAT_TYPE(b16[l]),  sccache2[csel][ ix][1 + 8*v_im],
71+                        fma(FLOAT_TYPE(b32[l]),  sccache2[csel][ ix][2 + 8*v_im],
72+                        fma(FLOAT_TYPE(b48[l]),  sccache2[csel][ ix][3 + 8*v_im],
73+                        fma(FLOAT_TYPE(b64[l]),  sccache2[csel][ ix][4 + 8*v_im],
74+                        fma(FLOAT_TYPE(b80[l]),  sccache2[csel][ ix][5 + 8*v_im],
75+                        fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ ix][6 + 8*v_im],
76+                        fma(FLOAT_TYPE(b112[l]), sccache2[csel][ ix][7 + 8*v_im], sum2))))))));
7677            }
7778            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
7879        }
0 commit comments