@@ -233,69 +233,111 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
233233#ifdef MMQ_SHMEM
234234void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
235235 const uint ib_k = ib / 8 ;
236- const uint iqs_k = (ib % 8 ) * 8 + iqs * 4 ;
236+ const uint iqs_k = (ib % 8 ) * 8 + iqs;
237237
238238 const uint qs_idx = (iqs_k / 32 ) * 8 + (iqs_k % 8 );
239- const uint qs_shift = ((iqs_k % 32 ) / 8 ) * 2 ;
239+ // const uint qs_shift = ((iqs_k % 32) / 8) * 2;
240240
241241 // Repack 4x4 quants into one int
242- const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
243- const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1 ] >> qs_shift) & 0x03030303;
244- const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2 ] >> qs_shift) & 0x03030303;
245- const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3 ] >> qs_shift) & 0x03030303;
242+ // const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
243+ // const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
244+ // const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
245+ // const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
246246
247- buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2 ) | (vals2 << 4 ) | (vals3 << 6 );
247+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
248248
249249 if (iqs == 0 ) {
250- buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8 ]);
251250 buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
251+ buf_a[buf_ib].scales[0 ] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 ]);
252+ }
253+ if (iqs == 1 ) {
254+ buf_a[buf_ib].scales[1 ] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1 ]);
252255 }
253256}
254257
255258void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
256259 cache_a[reg_ib].dm = buf_a[buf_ib].dm;
257- cache_a[reg_ib].scales = buf_a[buf_ib].scales;
258260
259261 [[unroll]] for (uint iqs = 0 ; iqs < 2 ; iqs++ ) {
262+ cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
263+ }
264+
265+ [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
260266 cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
261267 }
262268}
263269
264270ACC_TYPE mmq_dot_product(const uint ib_a) {
265- int32_t sum_d = 0 ;
266- int32_t sum_m = 0 ;
271+ float sum_d = 0 ;
272+ float sum_m = 0 ;
267273
268- uint8_t scale = cache_a[ib_a].scales[0 ];
269- int32_t scale_m = int32_t(scale >> 4 );
270- scale_m |= scale_m << 8 ;
271- scale_m |= scale_m << 16 ;
274+ [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
275+ const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
276+ [[unroll]] for (uint ib_b = 0 ; ib_b < 4 ; ib_b++ ) {
277+ const uint8_t scale = cache_a[ib_a].scales[ib_b / 2 ][(ib_b % 2 ) * 2 + (iqs / 4 )];
278+ const int32_t scale_m = int32_t(scale >> 4 ) * 0x01010101; // Duplicate 8-bit value across 32-bits.
279+ const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2 )) & 0x03030303);
280+
281+ sum_d += cache_b.ds[ib_b].x * float (dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
282+ sum_m += cache_b.ds[ib_b].x * float (dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
283+ }
284+ }
272285
273- [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
274- const uint qs_shift = iqs * 2 ;
286+ return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
287+ }
288+ #endif // MMQ_SHMEM
289+ #endif
275290
276- const int32_t qs_a = int32_t((cache_a[ib_a].qs[0 ] >> qs_shift) & 0x03030303);
291+ #ifdef MMQ_SHMEM
292+ #if defined(DATA_A_QUANT_LEGACY)
293+ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
294+ const uint ib_outer = ib / 4 ;
295+ const uint ib_inner = ib % 4 ;
277296
278- sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[ iqs]) * (scale & 0xF);
279- sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs ]);
297+ if ( iqs == 0 ) {
298+ buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner ]);
280299 }
281300
282- scale = cache_a[ib_a].scales[1 ];
283- scale_m = int32_t(scale >> 4 );
284- scale_m |= scale_m << 8 ;
285- scale_m |= scale_m << 16 ;
286-
287- [[unroll]] for (uint iqs = 4 ; iqs < 8 ; iqs++ ) {
288- const uint qs_shift = (iqs - 4 ) * 2 ;
289-
290- const int32_t qs_a = int32_t((cache_a[ib_a].qs[1 ] >> qs_shift) & 0x03030303);
301+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
302+ buf_b[buf_ib].qs[iqs * 4 ] = values.x;
303+ buf_b[buf_ib].qs[iqs * 4 + 1 ] = values.y;
304+ buf_b[buf_ib].qs[iqs * 4 + 2 ] = values.z;
305+ buf_b[buf_ib].qs[iqs * 4 + 3 ] = values.w;
306+ }
291307
292- sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
293- sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
308+ void block_b_to_registers(const uint ib) {
309+ cache_b.ds = buf_b[ib].ds;
310+ [[unroll]] for (uint iqs = 0 ; iqs < BK / 4 ; iqs++ ) {
311+ cache_b.qs[iqs] = buf_b[ib].qs[iqs];
312+ }
313+ }
314+ #elif defined(DATA_A_QUANT_K)
315+ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
316+ const uint ib_outer = ib / 4 ;
317+
318+ buf_b[buf_ib].ds[iqs * 2 ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 ]);
319+ buf_b[buf_ib].ds[iqs * 2 + 1 ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1 ]);
320+
321+ [[unroll]] for (uint ib_inner = 0 ; ib_inner < 4 ; ib_inner++ ) {
322+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
323+ buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 ] = values.x;
324+ buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1 ] = values.y;
325+ buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2 ] = values.z;
326+ buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3 ] = values.w;
294327 }
328+ }
295329
296- return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1 );
330+ void block_b_to_registers(const uint ib) {
331+ [[unroll]] for (uint i = 0 ; i < 4 ; i++ ) {
332+ cache_b.ds[i] = buf_b[ib].ds[i];
333+ }
334+ [[unroll]] for (uint iqs = 0 ; iqs < 32 ; iqs++ ) {
335+ cache_b.qs[iqs] = buf_b[ib].qs[iqs];
336+ }
297337}
298- #endif // MMQ_SHMEM
338+ #else
339+ #error unimplemented
340+ #endif
299341#endif
300342
301343#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
0 commit comments