@@ -40,8 +40,8 @@ i32vec2 repack(uint ib, uint iqs) {
4040 (vui >> 4 ) & 0x0F0F0F0F);
4141}
4242
// mul_q8_1: turns an integer dot-product sum into a scaled floating-point value.
// Computes da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y).
// 8 and sum_divisor are both integer-typed, so the quotient is formed as an
// integer before it is multiplied by dsb.y.
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE), the arithmetic is identical.
// NOTE(review): da is presumably the A-block scale and dsb the B-block
// (scale, sum) pair -- confirm against the callers.
43- ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
44- return ACC_TYPE (da * (float (q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
43+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
44+ return FLOAT_TYPE (da * (float (q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
4545}
4646#endif
4747
@@ -53,8 +53,8 @@ i32vec2 repack(uint ib, uint iqs) {
5353 (vui >> 4 ) & 0x0F0F0F0F);
5454}
5555
// mul_q8_1 (two-component dma variant): turns an integer dot-product sum into
// a scaled floating-point value.
// Computes float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor.
// dma.y * dsb.y is a float product, so the division by sum_divisor is a
// floating-point division.
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE).
// NOTE(review): dma presumably carries the A-block (scale, min) pair and dsb
// the B-block (scale, sum) pair -- confirm against the callers.
56- ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
57- return ACC_TYPE (float (q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
56+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
57+ return FLOAT_TYPE (float (q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
5858}
5959#endif
6060
@@ -74,8 +74,8 @@ i32vec2 repack(uint ib, uint iqs) {
7474 return i32vec2(v0, v1);
7575}
7676
// mul_q8_1: turns an integer dot-product sum into a scaled floating-point value.
// Computes da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y).
// 16 and sum_divisor are both integer-typed, so the quotient is formed as an
// integer before it is multiplied by dsb.y.
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE), the arithmetic is identical.
// NOTE(review): the constant 16 looks like the zero-point offset of the
// quantized A values for this format -- confirm against the format definition.
77- ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
78- return ACC_TYPE (da * (float (q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
77+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
78+ return FLOAT_TYPE (da * (float (q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
7979}
8080#endif
8181
@@ -95,8 +95,8 @@ i32vec2 repack(uint ib, uint iqs) {
9595 return i32vec2(v0, v1);
9696}
9797
// mul_q8_1 (two-component dma variant): turns an integer dot-product sum into
// a scaled floating-point value.
// Computes float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor, where
// the division is a floating-point division (its left operand is a float
// product).
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE).
98- ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
99- return ACC_TYPE (float (q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
98+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
99+ return FLOAT_TYPE (float (q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
100100}
101101#endif
102102
@@ -107,8 +107,8 @@ int32_t repack(uint ib, uint iqs) {
107107 data_a_packed16[ib].qs[iqs * 2 + 1 ]));
108108}
109109
// mul_q8_1: scales an integer dot-product sum by the two block scales.
// Computes float(q_sum) * da * dsb.x.
// dsb.y and sum_divisor are not read in this variant; the parameters are
// kept so the signature matches the other mul_q8_1 variants in this file.
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE).
110- ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
111- return ACC_TYPE (float (q_sum) * da * dsb.x);
110+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
111+ return FLOAT_TYPE (float (q_sum) * da * dsb.x);
112112}
113113#endif
114114
@@ -127,8 +127,8 @@ i32vec2 repack(uint ib, uint iqs) {
127127 pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
128128}
129129
// mul_q8_1: scales an integer dot-product sum by the two block scales and
// halves the result.
// Computes da * dsb.x * float(q_sum) * 0.5.
// dsb.y and sum_divisor are not read in this variant.
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE).
// NOTE(review): the 0.5 factor presumably compensates for table values that
// are stored doubled (see the kvalues_mxfp4 lookup in the preceding repack)
// -- confirm against the table definition.
130- ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
131- return ACC_TYPE (da * dsb.x * float (q_sum) * 0.5 );
130+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
131+ return FLOAT_TYPE (da * dsb.x * float (q_sum) * 0.5 );
132132}
133133#endif
134134
@@ -157,14 +157,15 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
157157
158158#if defined(DATA_A_Q2_K)
159159// 4-byte loads for Q2_K blocks (84 bytes)
// repack / repack2: extracts packed 2-bit values from data_a_packed32[].qs.
// Diff note: the '-' lines are the previous version (repack, returning one
// int32_t word); the '+' lines are the replacement (repack2, returning an
// i32vec2 holding the words at qs_idx and qs_idx + 1 under the same shift).
// Parameters:
//   ib  - index that is split into an outer index (ib / 8) and a sub-position
//         (ib % 8)
//   iqs - offset within that sub-position
// Returns: each selected 32-bit word shifted right by qs_shift and masked with
// 0x03030303, i.e. the low two bits of every byte are kept.
// NOTE(review): the second component equals what the old repack(ib, iqs + 1)
// returned only while (iqs_k % 8) != 7; the call site visible in this diff
// passes iqs * 2 -- confirm that all callers pass an even iqs.
160- int32_t repack (uint ib, uint iqs) {
160+ i32vec2 repack2 (uint ib, uint iqs) {
// Eight consecutive ib values share one data_a_packed32 entry.
161161 const uint ib_k = ib / 8 ;
162162 const uint iqs_k = (ib % 8 ) * 8 + iqs;
163163
// Each run of 32 iqs_k positions maps onto 8 words; the position within the
// run selects a bit offset of 0, 2, 4 or 6 inside each byte.
164164 const uint qs_idx = (iqs_k / 32 ) * 8 + (iqs_k % 8 );
165165 const uint qs_shift = ((iqs_k % 32 ) / 8 ) * 2 ;
166166
167- return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
167+ return i32vec2((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
168+ (data_a_packed32[ib_k].qs[qs_idx + 1 ] >> qs_shift) & 0x03030303);
168169}
169170
170171uint8_t get_scale(uint ib, uint iqs) {
@@ -178,25 +179,24 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
178179 int32_t sum_d = 0 ;
179180 int32_t sum_m = 0 ;
180181
181- const int32_t qs_a0 = repack(ib_a, iqs * 2 );
182- const int32_t qs_a1 = repack(ib_a, iqs * 2 + 1 );
182+ const i32vec2 qs_a = repack2(ib_a, iqs * 2 );
183183 const uint8_t scale = get_scale(ib_a, iqs * 2 );
184+ const vec2 dm = vec2 (get_dm(ib_a));
184185 const int32_t scale_m = int32_t(scale >> 4 ) * 0x01010101; // Duplicate 8-bit value across 32-bits.
185186
186- sum_d += dotPacked4x8EXT(qs_a0 , cache_b_qs[0 ]) * (scale & 0xF);
187+ sum_d += dotPacked4x8EXT(qs_a.x , cache_b_qs[0 ]) * (scale & 0xF);
187188 sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0 ]);
188189
189- sum_d += dotPacked4x8EXT(qs_a1 , cache_b_qs[1 ]) * (scale & 0xF);
190+ sum_d += dotPacked4x8EXT(qs_a.y , cache_b_qs[1 ]) * (scale & 0xF);
190191 sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1 ]);
191192
192- const vec2 dm = get_dm(ib_a);
193- return ACC_TYPE(float (cache_b_ds.x) * (float (dm.x) * float (sum_d) - float (dm.y) * float (sum_m) / 4 ));
193+ return FLOAT_TYPE(float (cache_b_ds.x) * (float (dm.x) * float (sum_d) - float (dm.y) * float (sum_m)));
194194}
195195#endif
196196
197197#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
198198// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
// mul_q8_1 (compiled when DATA_A_Q4_K or DATA_A_Q5_K is defined): turns an
// integer dot-product sum into a scaled floating-point value.
// Computes dsb.x * dma.x * float(q_sum) - dma.y * dsb.y.
// sum_divisor is not read in this variant.
// Diff note: the '-' lines are the previous version and the '+' lines the
// replacement; only the return type and the result cast change
// (ACC_TYPE -> FLOAT_TYPE).
199- ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
200- return ACC_TYPE (dsb.x * dma.x * float (q_sum) - dma.y * dsb.y);
199+ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
200+ return FLOAT_TYPE (dsb.x * dma.x * float (q_sum) - dma.y * dsb.y);
201201}
202202#endif
0 commit comments