@@ -279,6 +279,73 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
279279#endif // MMQ_SHMEM
280280#endif
281281
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
// Load one 32-quant sub-block of a Q3_K super-block from global memory,
// reassemble the 3-bit quants (2-bit low part from qs + 1 high bit from hmask),
// and store them nibble-packed into shared memory, together with the two
// dequantization scales for this sub-block.
// buf_ib: destination slot in the shared-memory buffer buf_a.
// ib:     linear sub-block index; ib / 8 selects the Q3_K super-block (8 sub-blocks of 32 quants each).
// iqs:    position within the sub-block handled by this invocation.
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
    const uint ib_k = ib / 8;
    // NOTE(review): assumes QUANT_R_MMQ scales iqs to a 16-bit-word index into hmask — defined outside this view, confirm.
    const uint hm_idx = iqs * QUANT_R_MMQ;
    const uint iqs_k = (ib % 8) * 8 + hm_idx;

    // Q3_K layout: qs holds 4 interleaved 2-bit planes (selected by qs_shift),
    // hmask holds the high bit of each quant (selected by hm_shift).
    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
    const uint hm_shift = iqs_k / 8;

    // Repack 2x4 quants into one int
    // Add the 3rd bit instead of subtracting it to allow packing the quants
    // (values stay in 0..7, so each fits a nibble; the -4 offset is applied in mmq_dot_product).
    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
    // Low nibbles: first 4 quants; high nibbles: next 4 quants.
    buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
                            (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);

    if (iqs == 0) {
        // Reassemble the two 6-bit Q3_K scales for this sub-block pair:
        // low 4 bits from scales[0..7], high 2 bits packed into scales[8..11].
        const uint is = iqs_k / 4;
        const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8) / 2] >> (4 * (is / 8))) & 0x0F0F) |
                                             (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));

        // Pre-multiply by the super-block scale d; the -32 recenters the 6-bit scale.
        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
    }
}
315+
// Copy one repacked Q3_K block (4 packed-quant words + 2 scales) from the
// shared-memory staging buffer into this invocation's register cache.
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
    [[unroll]] for (uint i = 0; i < 4; i++) {
        cache_a[reg_ib].qs[i] = buf_a[buf_ib].qs[i];
    }

    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
}
323+
// Compute the dot product of one cached Q3_K block (cache_a[ib_a]) with the
// cached q8_1-style B block (cache_b): two groups of 4 packed int8x4 dot
// products, each scaled by its own precomputed d*scale factor, then scaled
// by the B block's scale cache_b.ds.x.
ACC_TYPE mmq_dot_product(const uint ib_a) {
    float result = 0.0;
    int32_t q_sum = 0;

    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
        // Unpack 4 nibbles (low nibbles for even iqs, high for odd) into bytes.
        // Subtract 4 from the quants to correct the 3rd bit offset
        // (block_a_to_shmem stored quants with the high bit added, range 0..7).
        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));

        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    // First half of the sub-block uses the first scale.
    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
    q_sum = 0;

    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));

        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
    }
    // Second half uses the second scale.
    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

    return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
348+
282349#if defined(DATA_A_Q4_K)
283350// 4-byte loads for Q4_K blocks (144 bytes)
284351ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
0 commit comments