@@ -233,7 +233,7 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
233233#ifdef MMQ_SHMEM
234234void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
235235 const uint ib_k = ib / 8 ;
236- const uint iqs_k = (ib % 8 ) * 8 + iqs * 4 ;
236+ const uint iqs_k = (ib % 8 ) * 8 + iqs * QUANT_R_MMQ ;
237237
238238 const uint qs_idx = (iqs_k / 32 ) * 8 + (iqs_k % 8 );
239239 const uint qs_shift = ((iqs_k % 32 ) / 8 ) * 2 ;
@@ -279,6 +279,63 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
279279#endif // MMQ_SHMEM
280280#endif
281281
282+ #if defined(DATA_A_Q4_K)
283+ // 4-byte loads for Q4_K blocks (144 bytes)
// Turns an integer dot-product sum into the final scaled value:
//   dsb.x * dma.x * q_sum  -  dma.y * dsb.y
// q_sum       - integer sum, converted to float before scaling.
// dma         - scale pair for the A side (.x multiplies the sum, .y feeds the subtracted term).
// dsb         - scale pair for the B side (.x multiplies the sum, .y feeds the subtracted term).
// sum_divisor - not referenced in this body.
// NOTE(review): presumably dma = (scale, min) of the A sub-block and dsb = (scale, sum) of the
// q8_1 B block - confirm against the block struct definitions.
284+ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
285+ return ACC_TYPE(dsb.x * dma.x * float (q_sum) - dma.y * dsb.y);
286+ }
287+
288+ #ifdef MMQ_SHMEM
// Loads one 32-bit word of repacked 4-bit quants (and, for iqs == 0, the combined scale pair)
// from data_a / data_a_packed32 into buf_a[buf_ib].
// buf_ib - destination index into buf_a.
// ib     - source index; ib / 8 selects the data_a element, ib % 8 the position inside it.
// iqs    - index of the qs word written in buf_a[buf_ib]; only the call with iqs == 0 writes dm.
// NOTE(review): buf_a is presumably the shared-memory staging buffer (per the MMQ_SHMEM guard
// and the function name) - its declaration is not visible here.
289+ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
290+ const uint ib_k = ib / 8 ;
291+ const uint iqs_k = (ib % 8 ) * 8 + iqs * QUANT_R_MMQ;
292+
// Word index into the packed qs array, plus a shift of 0 or 4 that selects the low or the
// high nibble of every byte in that word.
293+ const uint qs_idx = (iqs_k / 16 ) * 8 + (iqs_k % 8 );
294+ const uint qs_shift = ((iqs_k % 16 ) / 8 ) * 4 ;
295+
296+ // Repack 2x4 quants into one int
// Each of the two consecutive words is reduced to one 4-bit value per byte.
297+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
298+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1 ] >> qs_shift) & 0x0F0F0F0F;
299+
// vals0 ends up in the low nibble of every byte, vals1 in the high nibble.
300+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4 );
301+
302+ if (iqs == 0 ) {
303+ // Scale index
// With iqs == 0, iqs_k is (ib % 8) * 8, so `is` equals ib % 8 (range 0..7).
304+ const uint is = iqs_k / 8 ;
305+ u8vec2 scale_dm;
306+ if (is < 4 ) {
// First four entries: both values are the low 6 bits of a single scales byte.
307+ scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4 ] & 0x3F);
308+ } else {
// Last four entries: each 6-bit value is assembled from a nibble of scales[is + 4]
// (bits 0-3) and the top two bits of another scales byte moved down to bits 4-5.
309+ scale_dm = u8vec2((data_a[ib_k].scales[is+ 4 ] & 0xF) | ((data_a[ib_k].scales[is- 4 ] & 0xC0) >> 2 ),
310+ (data_a[ib_k].scales[is+ 4 ] >> 4 ) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2 ));
311+ }
312+
// Store the source dm pair multiplied by the unpacked pair, so later stages read a single
// pre-combined vec2. NOTE(review): assumes FLOAT_TYPE_VEC2 is a 2-component float vector,
// making the product component-wise - confirm against its definition.
313+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
314+ }
315+ }
316+
// Copies one staged A entry from buf_a[buf_ib] into cache_a[reg_ib]: the dm pair and the
// four qs words.
// reg_ib - destination index into cache_a.
// buf_ib - source index into buf_a.
// NOTE(review): cache_a is presumably a per-invocation (register) array, as the function name
// suggests - its declaration is not visible here.
317+ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
318+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
319+
320+ [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
321+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
322+ }
323+ }
324+
// Computes the scaled dot product between cached A entry ib_a and the cached B block.
// ib_a - index into cache_a.
// Returns mul_q8_1 applied to the accumulated integer sum, cache_a[ib_a].dm and cache_b.ds.
325+ ACC_TYPE mmq_dot_product(const uint ib_a) {
326+ int32_t q_sum = 0 ;
327+
328+ [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
// Each cached A word holds two sets of 4-bit values: even iqs takes the low nibble of
// every byte of qs[iqs / 2], odd iqs takes the high nibble.
329+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F);
330+
// Packed 4x8-bit integer dot product against the matching B word, accumulated in q_sum.
331+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
332+ }
333+
// The trailing 1 is the sum_divisor argument of mul_q8_1.
334+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1 );
335+ }
336+ #endif // MMQ_SHMEM
337+ #endif
338+
282339#ifdef MMQ_SHMEM
283340void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
284341 const uint ib_outer = ib / 4 ;
0 commit comments