@@ -9,8 +9,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
99
1010void main() {
1111    [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12-         const uint i  = gl_WorkGroupID.x * 256 + wgy;
13-         if (i  >= p.M * p.K / QUANT_K) {
12+         const uint ib  = gl_WorkGroupID.x * 256 + wgy;
13+         if (ib  >= p.M * p.K / QUANT_K) {
1414            return;
1515        }
1616
@@ -19,40 +19,52 @@ void main() {
1919        const uint ir = tid % 16;
2020        const uint is = 2 * il;
2121
22-         const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i ].d.x);
23-         const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i ].d.y);
22+         const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib ].d.x);
23+         const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib ].d.y);
2424
25-         const uint y_idx = i  * QUANT_K + 64 * il + 2 * ir;
25+         const uint y_idx = ib  * QUANT_K + 64 * il + 2 * ir;
2626        const uint qs_idx = 32*il + 2 * ir;
2727        const uint qh_idx = 2 * ir;
2828
29-         uint8_t sc;
30-         uint8_t m;
31-         if (is < 4) {
32-             sc = uint8_t(data_a[i].scales[is] & 63);
33-             m  = uint8_t(data_a[i].scales[is + 4] & 63);
34-         } else {
35-             sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
36-             m  = uint8_t((data_a[i].scales[is + 4] >>  4) | ((data_a[i].scales[is    ] >> 6) << 4));
37-         }
29+         uint scidx0 = (is < 4) ? is : (is + 4);
30+         uint scidx1 = (is < 4) ? is : (is - 4);
31+         uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
32+         uint scidxshift1 = (is < 4) ? 0 : 2;
33+         uint mbidx0 = is + 4;
34+         uint mbidx1 = (is < 4) ? is + 4 : is;
35+         uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
36+         uint mbidxshift0 = (is < 4) ? 0 : 4;
37+         uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
38+         uint mbidxshift1 = (is < 4) ? 0 : 2;
39+ 
40+         uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
41+         uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
42+ 
3843        const FLOAT_TYPE d1 = dall * sc;
39-         const FLOAT_TYPE m1 = dmin * m;
40- 
41-         if (is < 4) {
42-             sc = uint8_t(data_a[i].scales[is + 1] & 63);
43-             m  = uint8_t(data_a[i].scales[is + 5] & 63);
44-         } else {
45-             sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
46-             m  = uint8_t((data_a[i].scales[is + 5] >>  4) | ((data_a[i].scales[is + 1] >> 6) << 4));
47-         }
44+         const FLOAT_TYPE m1 = dmin * mbyte;
45+ 
46+         scidx0 = (is < 4) ? is + 1 : (is + 5);
47+         scidx1 = (is < 4) ? is + 1 : (is - 3);
48+         scidxmask1 = (is < 4) ? 0x30 : 0xC0;
49+         scidxshift1 = (is < 4) ? 0 : 2;
50+         mbidx0 = is + 5;
51+         mbidx1 = (is < 4) ? is + 5 : is + 1;
52+         mbidxmask0 = (is < 4) ? 0xF : 0xF0;
53+         mbidxshift0 = (is < 4) ? 0 : 4;
54+         mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
55+         mbidxshift1 = (is < 4) ? 0 : 2;
56+ 
57+         sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
58+         mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
59+ 
4860        const FLOAT_TYPE d2 = dall * sc;
49-         const FLOAT_TYPE m2 = dmin * m ;
61+         const FLOAT_TYPE m2 = dmin * mbyte ;
5062
5163        const uint8_t hm1 = uint8_t(1 << (2 * il    ));
5264        const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
53-         data_b[y_idx     ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i ].qs[qs_idx    ] & 0xF) + (((data_a[i ].qh[qh_idx    ] & hm1) != 0) ? 16 : 0)) - m1);
54-         data_b[y_idx +  1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i ].qs[qs_idx + 1] & 0xF) + (((data_a[i ].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
55-         data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i ].qs[qs_idx    ]  >> 4) + (((data_a[i ].qh[qh_idx    ] & hm2) != 0) ? 16 : 0)) - m2);
56-         data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i ].qs[qs_idx + 1]  >> 4) + (((data_a[i ].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
65+         data_b[y_idx     ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib ].qs[qs_idx    ] & 0xF) + (((data_a[ib ].qh[qh_idx    ] & hm1) != 0) ? 16 : 0)) - m1);
66+         data_b[y_idx +  1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib ].qs[qs_idx + 1] & 0xF) + (((data_a[ib ].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
67+         data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib ].qs[qs_idx    ]  >> 4) + (((data_a[ib ].qh[qh_idx    ] & hm2) != 0) ? 16 : 0)) - m2);
68+         data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib ].qs[qs_idx + 1]  >> 4) + (((data_a[ib ].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
5769    }
5870}
0 commit comments