@@ -233,63 +233,53 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
233233#ifdef MMQ_SHMEM
234234void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
235235 const uint ib_k = ib / 8 ;
236- const uint iqs_k = (ib % 8 ) * 8 + iqs;
236+ const uint iqs_k = (ib % 8 ) * 8 + iqs * 4 ;
237237
238238 const uint qs_idx = (iqs_k / 32 ) * 8 + (iqs_k % 8 );
239- // const uint qs_shift = ((iqs_k % 32) / 8) * 2;
239+ const uint qs_shift = ((iqs_k % 32 ) / 8 ) * 2 ;
240240
241241 // Repack 4x4 quants into one int
242- // const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
243- // const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
244- // const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
245- // const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
242+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
243+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1 ] >> qs_shift) & 0x03030303;
244+ const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2 ] >> qs_shift) & 0x03030303;
245+ const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3 ] >> qs_shift) & 0x03030303;
246246
247- buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
247+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2 ) | (vals2 << 4 ) | (vals3 << 6 );
248248
249249 if (iqs == 0 ) {
250250 buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
251- buf_a[buf_ib].scales[0 ] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 ]);
252- }
253- if (iqs == 1 ) {
254- buf_a[buf_ib].scales[1 ] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1 ]);
251+ buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8 ]);
255252 }
256253}
257254
258255void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
259256 cache_a[reg_ib].dm = buf_a[buf_ib].dm;
257+ cache_a[reg_ib].scales = buf_a[buf_ib].scales;
260258
261259 [[unroll]] for (uint iqs = 0 ; iqs < 2 ; iqs++ ) {
262- cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
263- }
264-
265- [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
266260 cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
267261 }
268262}
269263
270264ACC_TYPE mmq_dot_product(const uint ib_a) {
271- float sum_d = 0 ;
272- float sum_m = 0 ;
265+ int32_t sum_d = 0 ;
266+ int32_t sum_m = 0 ;
273267
274268 [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
275- const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
276- [[unroll]] for (uint ib_b = 0 ; ib_b < 4 ; ib_b++ ) {
277- const uint8_t scale = cache_a[ib_a].scales[ib_b / 2 ][(ib_b % 2 ) * 2 + (iqs / 4 )];
278- const int32_t scale_m = int32_t(scale >> 4 ) * 0x01010101; // Duplicate 8-bit value across 32-bits.
279- const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2 )) & 0x03030303);
280-
281- sum_d += cache_b.ds[ib_b].x * float (dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
282- sum_m += cache_b.ds[ib_b].x * float (dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
283- }
269+ const uint8_t scale = cache_a[ib_a].scales[iqs / 4 ];
270+ const int32_t scale_m = int32_t(scale >> 4 ) * 0x01010101; // Duplicate 8-bit value across 32-bits.
271+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4 ] >> ((iqs % 4 ) * 2 )) & 0x03030303);
272+
273+ sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
274+ sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
284275 }
285276
286- return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m );
277+ return mul_q8_1( sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1 );
287278}
288279#endif // MMQ_SHMEM
289280#endif
290281
291282#ifdef MMQ_SHMEM
292- #if defined(DATA_A_QUANT_LEGACY)
293283void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
294284 const uint ib_outer = ib / 4 ;
295285 const uint ib_inner = ib % 4 ;
@@ -311,33 +301,6 @@ void block_b_to_registers(const uint ib) {
311301 cache_b.qs[iqs] = buf_b[ib].qs[iqs];
312302 }
313303}
314- #elif defined(DATA_A_QUANT_K)
315- void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
316- const uint ib_outer = ib / 4 ;
317-
318- buf_b[buf_ib].ds[iqs * 2 ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 ]);
319- buf_b[buf_ib].ds[iqs * 2 + 1 ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1 ]);
320-
321- [[unroll]] for (uint ib_inner = 0 ; ib_inner < 4 ; ib_inner++ ) {
322- const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
323- buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 ] = values.x;
324- buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1 ] = values.y;
325- buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2 ] = values.z;
326- buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3 ] = values.w;
327- }
328- }
329-
330- void block_b_to_registers(const uint ib) {
331- [[unroll]] for (uint i = 0 ; i < 4 ; i++ ) {
332- cache_b.ds[i] = buf_b[ib].ds[i];
333- }
334- [[unroll]] for (uint iqs = 0 ; iqs < 32 ; iqs++ ) {
335- cache_b.qs[iqs] = buf_b[ib].qs[iqs];
336- }
337- }
338- #else
339- #error unimplemented
340- #endif
341304#endif
342305
343306#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
0 commit comments