@@ -59,9 +59,6 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
5959 }
6060}
6161
62- void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) {
63- }
64-
6562ACC_TYPE mmq_dot_product(const uint ib_a) {
6663 int32_t q_sum = 0 ;
6764 [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
@@ -205,6 +202,61 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
205202#endif // MMQ_SHMEM
206203#endif
207204
205+ #if defined(DATA_A_MXFP4)
206+ // 1-byte loads for mxfp4 blocks (17 bytes)
207+ i32vec2 repack(uint ib, uint iqs) { // Reads qs bytes iqs*4 .. iqs*4+3 of data_a[ib] and returns their low and high nibbles as two packed 32-bit words.
208+ const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], // each byte is indexed individually, then the four are packed into one 32-bit word
209+ data_a[ib].qs[iqs * 4 + 1 ],
210+ data_a[ib].qs[iqs * 4 + 2 ],
211+ data_a[ib].qs[iqs * 4 + 3 ]));
212+
213+ return i32vec2( quants & 0x0F0F0F0F, // .x: low nibble of each of the four bytes
214+ (quants >> 4 ) & 0x0F0F0F0F); // .y: high nibble of each of the four bytes
215+ }
216+
217+ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { // Scales the integer sum q_sum by da and dsb.x and returns it as ACC_TYPE.
218+ return ACC_TYPE(da * dsb.x * float (q_sum)); // dsb.y and sum_divisor are not read in this variant
219+ }
220+
221+ #ifdef MMQ_SHMEM
222+ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { // Converts qs bytes iqs*4 .. iqs*4+3 of data_a[ib] through the kvalues_mxfp4 table and stores them in buf_a[buf_ib].
223+ const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], // each byte is indexed individually, then the four are packed into one 32-bit word
224+ data_a[ib].qs[iqs * 4 + 1 ],
225+ data_a[ib].qs[iqs * 4 + 2 ],
226+ data_a[ib].qs[iqs * 4 + 3 ]));
227+
228+ const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); // low nibble of each byte, used as a table index
229+ const u8vec4 i_a1 = unpack8((qs >> 4 ) & 0x0F0F0F0F); // high nibble of each byte, used as a table index
230+
231+ buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])); // looked-up low-nibble values go to word iqs
232+ buf_a[buf_ib].qs[iqs + 4 ] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); // looked-up high-nibble values go to word iqs + 4
233+
234+ if (iqs == 0 ) { // the scale is written only by the iqs == 0 call
235+ buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5 ); // NOTE(review): the 0.5 factor presumably compensates for kvalues_mxfp4 storing doubled values - confirm against the table definition
236+ }
237+ }
238+
239+ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { // Copies the scale d and 8 qs words from buf_a[buf_ib] into cache_a[reg_ib].
240+ cache_a[reg_ib].d = buf_a[buf_ib].d;
241+
242+ [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) { // qs words 0..7 are copied one by one
243+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
244+ }
245+ }
246+
247+ ACC_TYPE mmq_dot_product(const uint ib_a) { // Accumulates dotPacked4x8EXT over qs words 0..7 of cache_a[ib_a] and cache_b, then scales the sum with mul_q8_1.
248+ int32_t q_sum = 0 ;
249+ [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
250+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
251+
252+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); // dot product of the four packed 8-bit components of each word
253+ }
254+
255+ return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1 ); // the trailing 1 is the sum_divisor argument
256+ }
257+ #endif // MMQ_SHMEM
258+ #endif
259+
208260// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
209261// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
210262#if defined(DATA_A_Q2_K)
0 commit comments