@@ -10,18 +10,18 @@ FLOAT_TYPE get_dm(uint ib) {
1010}
1111#endif
1212
13- #if defined(DATA_A_MXFP4)
14- FLOAT_TYPE get_dm(uint ib) {
15- return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
16- }
17- #endif
18-
1913#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
2014FLOAT_TYPE_VEC2 get_dm(uint ib) {
2115 return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
2216}
2317#endif
2418
19+ #if defined(DATA_A_MXFP4)
20+ FLOAT_TYPE get_dm(uint ib) {
21+ return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
22+ }
23+ #endif
24+
2525#if defined(DATA_A_Q2_K)
2626FLOAT_TYPE_VEC2 get_dm(uint ib) {
2727 const uint ib_k = ib / 8 ;
@@ -115,22 +115,25 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int
115115#if defined(DATA_A_MXFP4)
116116// 1-byte loads for mxfp4 blocks (17 bytes)
117117i32vec2 repack(uint ib, uint iqs) {
118- const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
119- data_a[ib].qs[iqs * 4 + 1 ],
120- data_a[ib].qs[iqs * 4 + 2 ],
121- data_a[ib].qs[iqs * 4 + 3 ]));
118+ const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
119+ data_a[ib].qs[iqs * 4 + 1 ],
120+ data_a[ib].qs[iqs * 4 + 2 ],
121+ data_a[ib].qs[iqs * 4 + 3 ]));
122+
123+ const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
124+ const u8vec4 i_a1 = unpack8((qs >> 4 ) & 0x0F0F0F0F);
122125
123- return i32vec2( quants & 0x0F0F0F0F ,
124- (quants >> 4 ) & 0x0F0F0F0F );
126+ return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])) ,
127+ pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])) );
125128}
126129
127130ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
128- return ACC_TYPE(da * dsb.x * float (q_sum));
131+ return ACC_TYPE(da * dsb.x * float (q_sum) * 0.5 );
129132}
130133#endif
131134
132135#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
133- FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs, const int32_t sum_divisor ) {
136+ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
134137 int32_t q_sum = 0 ;
135138#if QUANT_R == 2
136139 const i32vec2 data_a_qs = repack(ib_a, iqs);
@@ -147,7 +150,8 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs, const int32_t sum_d
147150 cache_b_qs[1 ]);
148151#endif
149152
150- return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, sum_divisor);
153+ // 2 quants per call => divide sums by 8/2 = 4
154+ return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4 );
151155}
152156#endif
153157
@@ -170,8 +174,23 @@ uint8_t get_scale(uint ib, uint iqs) {
170174 return data_a[ib_k].scales[iqs_k / 4 ];
171175}
172176
173- ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
174- return ACC_TYPE(dsb.x * (dma.x * float (sum_d) - dma.y * float (sum_m)));
177+ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
178+ int32_t sum_d = 0 ;
179+ int32_t sum_m = 0 ;
180+
181+ const int32_t qs_a0 = repack(ib_a, iqs * 2 );
182+ const int32_t qs_a1 = repack(ib_a, iqs * 2 + 1 );
183+ const uint8_t scale = get_scale(ib_a, iqs * 2 );
184+ const int32_t scale_m = int32_t(scale >> 4 ) * 0x01010101; // Duplicate 8-bit value across 32-bits.
185+
186+ sum_d += dotPacked4x8EXT(qs_a0, cache_b_qs[0 ]) * (scale & 0xF);
187+ sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0 ]);
188+
189+ sum_d += dotPacked4x8EXT(qs_a1, cache_b_qs[1 ]) * (scale & 0xF);
190+ sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1 ]);
191+
192+ const vec2 dm = get_dm(ib_a);
193+ return ACC_TYPE(float (cache_b_ds.x) * (float (dm.x) * float (sum_d) - float (dm.y) * float (sum_m) / 4 ));
175194}
176195#endif
177196
0 commit comments