@@ -244,17 +244,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
244244 const uint iqs = idx % 128 ; // 0..127
245245
246246 const uint n = iqs / 64 ; // 0,1
247- const uint b = (iqs % 64 ) / 32 ; // 0,1
247+ const uint b = (( iqs % 64 ) / 32 ) * 4 ; // 0,4
248248 const uint is_b = (iqs % 16 ) / 8 ; // 0,1
249249 const uint qhshift = ((iqs % 64 ) / 16 ) * 2 ; // 0,2,4,6
250250 const uint is = 8 * n + qhshift + is_b; // 0..15
251- const uint qsi = n * 64 + (iqs % 32 ) * 2 ; // 0,2,4..126
252- const uint qhi = n * 32 + (iqs % 16 ) * 2 ; // 0,2,4..62
251+ const uint qsi = n * 32 + (iqs % 32 ); // 0..63
252+ const uint qhi = n * 16 + (iqs % 16 ); // 0..31
253253
254254 const float dscale = float (data_a[ib].d) * float (data_a[ib].scales[is]);
255255
256- buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float (int8_t(((data_a[ib].ql[qsi ] >> (b * 4 )) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3 ) << 4 )) - 32 ),
257- dscale * float (int8_t(((data_a[ib].ql[qsi + 1 ] >> (b * 4 )) & 0xF) | (((data_a[ib].qh[qhi + 1 ] >> qhshift) & 3 ) << 4 )) - 32 ));
256+ const uint ql = (uint (data_a_packed16[ib].ql[qsi]) >> b) & 0x0F0F;
257+ const uint qh = (uint (data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303;
258+ const vec2 q = (vec2 (unpack8(ql | (qh << 4 )).xy) - 32 ) * dscale;
259+
260+ buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y);
258261#elif defined(DATA_A_IQ1_S)
259262 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
260263 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
0 commit comments