@@ -62,21 +62,25 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
// Unchanged context lines of the diff: in the visible code this three-argument
// block_a_to_registers has an empty body, so calling it performs no work.
// NOTE(review): presumably a stub for a configuration where no register copy is needed - confirm.
6262void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) {
6363}
6464
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2.
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
65- ACC_TYPE mmq_dot_product(const uint ib_a) {
66- int32_t q_sum = 0 ;
65+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
66+ i32vec2 q_sum = i32vec2( 0 ) ;
6767 [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
68- const uint32_t vui = cache_a[ib_a].qs[iqs];
69- const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
70- (vui >> 4 ) & 0x0F0F0F0F);
// One packed 32-bit word from each of the two A entries.
68+ const u32vec2 vui = u32vec2(cache_a[ib_a ].qs[iqs],
69+ cache_a[ib_a + 1 ].qs[iqs]);
// Split every byte into its low and high nibble:
//   x = low nibbles of entry ib_a,     y = high nibbles of entry ib_a,
//   z = low nibbles of entry ib_a + 1, w = high nibbles of entry ib_a + 1.
70+ const i32vec4 qs_a = i32vec4(vui.x & 0x0F0F0F0F, (vui.x >> 4 ) & 0x0F0F0F0F,
71+ vui.y & 0x0F0F0F0F, (vui.y >> 4 ) & 0x0F0F0F0F);
// The B words are loaded once and shared by both lanes.
72+ const i32vec2 qs_b = i32vec2(cache_b.qs[iqs],
73+ cache_b.qs[iqs + 4 ]);
74+
// Low nibbles pair with cache_b.qs[iqs], high nibbles with cache_b.qs[iqs + 4].
75+ q_sum.x += dotPacked4x8EXT(qs_a.x, qs_b.x);
76+ q_sum.y += dotPacked4x8EXT(qs_a.z, qs_b.x);
77+ q_sum.x += dotPacked4x8EXT(qs_a.y, qs_b.y);
78+ q_sum.y += dotPacked4x8EXT(qs_a.w, qs_b.y);
7179
72- const int32_t qs_b0 = cache_b.qs[iqs];
73- const int32_t qs_b1 = cache_b.qs[iqs + 4 ];
74-
75- q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
76- q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
7780 }
7881
79- return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1 );
// Each lane is finalized with its own A-side dm and the shared cache_b.ds.
82+ return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1 ),
83+ mul_q8_1(q_sum.y, cache_a[ib_a + 1 ].dm, cache_b.ds, 1 ));
8084}
8185#endif // MMQ_SHMEM
8286
@@ -140,24 +144,35 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
140144 }
141145}
142146
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2. In this variant each quant byte is a 4-bit nibble from qs with one
// extra bit taken from qh and placed at bit 4 of the byte.
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
143- ACC_TYPE mmq_dot_product(const uint ib_a) {
144- int32_t q_sum = 0 ;
147+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
148+ i32vec2 q_sum = i32vec2( 0 ) ;
145149 [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
146- const uint32_t vui = cache_a[ib_a].qs[iqs];
147- const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
148- const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
149- | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
150- const int32_t qs_a1 = int32_t((vui >> 4 ) & 0x0F0F0F0F)
151- | (((qh >> 16 ) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
152-
153- const int32_t qs_b0 = cache_b.qs[iqs];
154- const int32_t qs_b1 = cache_b.qs[iqs + 4 ];
155-
156- q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
157- q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
// The B words are loaded once and shared by both lanes.
150+ const i32vec2 qs_b = i32vec2(cache_b.qs[iqs ],
151+ cache_b.qs[iqs + 4 ]);
152+
// One packed 32-bit word of nibbles from each of the two A entries.
153+ const u32vec2 vui = u32vec2(cache_a[ib_a ].qs[iqs],
154+ cache_a[ib_a + 1 ].qs[iqs]);
155+
// Shift qh so the four extra bits needed by this iteration sit in bits 0..3 (for the low
// nibbles) and bits 16..19 (for the high nibbles).
156+ const int32_t qh_0 = int32_t(cache_a[ib_a ].qh >> (4 * iqs));
157+ const int32_t qh_1 = int32_t(cache_a[ib_a + 1 ].qh >> (4 * iqs));
158+
// Naming changed relative to the removed code: here qs_a0 holds BOTH halves of entry ib_a
// (x = low nibbles, y = high nibbles) and qs_a1 holds both halves of entry ib_a + 1, whereas
// the old qs_a0 / qs_a1 were the low / high halves of a single entry.
// '&' binds tighter than '|', so each component is nibbles | ((bits * 0x02040810) & 0x10101010):
// the multiply-and-mask spreads four adjacent qh bits to bit 4 of each of the four bytes.
159+ const i32vec2 qs_a0 = i32vec2(int32_t(vui.x & 0x0F0F0F0F) | ((qh_0 & 0xF) * 0x02040810) & 0x10101010, // (0,1,2,3) -> (4,12,20,28)
160+ int32_t((vui.x >> 4 ) & 0x0F0F0F0F) | (((qh_0 >> 16 ) & 0xF) * 0x02040810) & 0x10101010); // (16,17,18,19) -> (4,12,20,28)
161+
162+
163+ const i32vec2 qs_a1 = i32vec2(int32_t(vui.y & 0x0F0F0F0F) | ((qh_1 & 0xF) * 0x02040810) & 0x10101010,
164+ int32_t((vui.y >> 4 ) & 0x0F0F0F0F) | (((qh_1 >> 16 ) & 0xF) * 0x02040810) & 0x10101010);
165+
// Low-nibble halves pair with cache_b.qs[iqs] ...
166+ q_sum.x += dotPacked4x8EXT(qs_a0.x, qs_b.x);
167+ q_sum.y += dotPacked4x8EXT(qs_a1.x, qs_b.x);
168+
// ... and high-nibble halves pair with cache_b.qs[iqs + 4].
169+ q_sum.x += dotPacked4x8EXT(qs_a0.y, qs_b.y);
170+ q_sum.y += dotPacked4x8EXT(qs_a1.y, qs_b.y);
171+
158172 }
159173
160- return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1 );
// Each lane is finalized with its own A-side dm and the shared cache_b.ds.
174+ return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1 ),
175+ mul_q8_1(q_sum.y, cache_a[ib_a + 1 ].dm, cache_b.ds, 1 ));
161176}
162177#endif // MMQ_SHMEM
163178#endif
@@ -191,16 +206,16 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
191206 }
192207}
193208
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2. In this variant the A words are used as stored, with no unpacking.
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
194- ACC_TYPE mmq_dot_product(const uint ib_a) {
195- int32_t q_sum = 0 ;
209+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
210+ i32vec2 q_sum = i32vec2( 0 ) ;
196211 [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
197- const int32_t qs_a = cache_a[ib_a].qs[iqs];
// The B word is loaded once per iteration and reused for both A entries.
198212 const int32_t qs_b = cache_b.qs[iqs];
199-
200- q_sum += dotPacked4x8EXT(qs_a , qs_b);
213+ q_sum.x += dotPacked4x8EXT(cache_a[ib_a ].qs[iqs], qs_b);
214+ q_sum.y += dotPacked4x8EXT(cache_a[ib_a + 1 ].qs[iqs] , qs_b);
201215 }
202216
203- return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1 );
// Each lane is finalized with its own A-side dm and the shared cache_b.ds.
217+ return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1 ),
218+ mul_q8_1(q_sum.y, cache_a[ib_a + 1 ].dm, cache_b.ds, 1 ));
204219}
205220#endif // MMQ_SHMEM
206221#endif
@@ -261,21 +276,34 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
261276 }
262277}
263278
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2. Two integer sums are kept per lane: sum_d (quant dot products weighted
// by the low nibble of the scale byte) and sum_m (dot products of the replicated high nibble
// of the scale byte with the B values).
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
264- ACC_TYPE mmq_dot_product(const uint ib_a) {
265- int32_t sum_d = 0 ;
266- int32_t sum_m = 0 ;
279+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
280+ i32vec2 sum_d = i32vec2( 0 ) ;
281+ i32vec2 sum_m = i32vec2( 0 ) ;
267282
268283 [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
269- const uint8_t scale = cache_a[ib_a].scales[iqs / 4 ];
270- const int32_t scale_m = int32_t(scale >> 4 ) * 0x01010101; // Duplicate 8-bit value across 32-bits.
271- const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4 ] >> ((iqs % 4 ) * 2 )) & 0x03030303);
// One scale byte per group of four iterations, fetched for both A entries.
284+ const u8vec2 scale = u8vec2(cache_a[ib_a ].scales[iqs / 4 ],
285+ cache_a[ib_a + 1 ].scales[iqs / 4 ]);
286+
287+ const i32vec2 scale_m = i32vec2(int32_t(scale.x >> 4 ) * 0x01010101,
288+ int32_t(scale.y >> 4 ) * 0x01010101); // Duplicate 8-bit value across 32-bits.
289+
// Select one 2-bit field from every byte of word iqs / 4; (iqs % 4) picks which field.
// Unlike the removed line there is no explicit int32_t() cast; the i32vec2 constructor
// performs the conversion of each component.
290+ const i32vec2 qs_a = i32vec2((cache_a[ib_a ].qs[iqs / 4 ] >> ((iqs % 4 ) * 2 )) & 0x03030303,
291+ (cache_a[ib_a + 1 ].qs[iqs / 4 ] >> ((iqs % 4 ) * 2 )) & 0x03030303);
292+
// The B word is loaded once per iteration and reused by all four dot products below.
293+ const int32_t qs_b = cache_b.qs[iqs];
294+ sum_d.x += dotPacked4x8EXT(qs_a.x, qs_b) * (scale.x & 0xF);
295+ sum_d.y += dotPacked4x8EXT(qs_a.y, qs_b) * (scale.y & 0xF);
296+
297+ sum_m.x += dotPacked4x8EXT(scale_m.x, qs_b);
298+ sum_m.y += dotPacked4x8EXT(scale_m.y, qs_b);
272299
273- sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
274- sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
275300 }
276301
277- return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1 );
// Each lane is finalized from its own (sum_d, sum_m) pair and A-side dm, with the shared cache_b.ds.
302+ return ACC_TYPE_VEC2(mul_q8_1(sum_d.x, sum_m.x, cache_a[ib_a ].dm, cache_b.ds, 1 ),
303+ mul_q8_1(sum_d.y, sum_m.y, cache_a[ib_a + 1 ].dm, cache_b.ds, 1 ));
304+
278305}
306+
279307#endif // MMQ_SHMEM
280308#endif
281309
@@ -321,27 +349,34 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
321349 }
322350}
323351
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2. The eight B words are processed in two groups of four; each group's
// integer sum is scaled by its own d_scales entry before being added to the float result.
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
324- ACC_TYPE mmq_dot_product(const uint ib_a) {
325- float result = 0.0 ;
326- int32_t q_sum = 0 ;
352+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
353+ vec2 result = vec2 ( 0.0 ) ;
354+ i32vec2 q_sum = i32vec2( 0 ) ;
327355
328356 [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
329357 // Subtract 4 from the quants to correct the 3rd bit offset
330- const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F)) - int8_t(4 ));
// Word iqs / 2 supplies the data; (iqs % 2) selects the low or high nibble of each byte.
// The bytes are unpacked, offset by -4, and repacked into one 32-bit word per A entry.
358+ const i32vec2 qs_a = i32vec2(pack32(unpack8(int32_t((cache_a[ib_a ].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F)) - int8_t(4 )),
359+ pack32(unpack8(int32_t((cache_a[ib_a + 1 ].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F)) - int8_t(4 )));
331360
332- q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
361+ q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
362+ q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
333363 }
334- result += float (cache_a[ib_a].d_scales[0 ]) * float (q_sum);
335- q_sum = 0 ;
// First group done: apply d_scales[0] per lane, then reset the integer sums for the second group.
364+ result.x += float (cache_a[ib_a ].d_scales[0 ]) * float (q_sum.x);
365+ result.y += float (cache_a[ib_a + 1 ].d_scales[0 ]) * float (q_sum.y);
366+ q_sum = i32vec2(0 );
336367
337368 [[unroll]] for (uint iqs = 4 ; iqs < 8 ; iqs++ ) {
338- const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F)) - int8_t(4 ));
// Same extraction as the first loop, applied to the remaining words.
369+ const i32vec2 qs_a = i32vec2(pack32(unpack8(int32_t((cache_a[ib_a ].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F)) - int8_t(4 )),
370+ pack32(unpack8(int32_t((cache_a[ib_a + 1 ].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F)) - int8_t(4 )));
339371
340- q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
372+ q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
373+ q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
341374 }
342- result += float (cache_a[ib_a].d_scales[1 ]) * float (q_sum);
// Second group uses d_scales[1].
375+ result.x += float (cache_a[ib_a ].d_scales[1 ]) * float (q_sum.x);
376+ result.y += float (cache_a[ib_a + 1 ].d_scales[1 ]) * float (q_sum.y);
343377
344- return ACC_TYPE(cache_b.ds.x * result);
// Both lanes are multiplied by the same B-side factor cache_b.ds.x.
378+ return ACC_TYPE_VEC2(cache_b.ds.x * result.x,
379+ cache_b.ds.x * result.y);
345380}
346381#endif // MMQ_SHMEM
347382#endif
@@ -398,20 +433,24 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
398433 }
399434}
400435
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2. How the A word is obtained depends on the DATA_A_Q4_K define.
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
401- ACC_TYPE mmq_dot_product(const uint ib_a) {
402- int32_t q_sum = 0 ;
436+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
437+ i32vec2 q_sum = i32vec2( 0 ) ;
403438
404439 [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
405440#if defined(DATA_A_Q4_K)
406- const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F);
// Q4_K path: word iqs / 2 holds two nibbles per byte; (iqs % 2) selects low or high nibbles.
// Unlike the removed line there is no explicit int32_t() cast; the i32vec2 constructor
// performs the conversion of each component.
441+ const i32vec2 qs_a = i32vec2((cache_a[ib_a ].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F,
442+ (cache_a[ib_a + 1 ].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F);
407443#else // defined(DATA_A_Q5_K)
408- const int32_t qs_a = cache_a[ib_a].qs[iqs];
// Other path: the cached word is used as stored, one word per iteration.
444+ const i32vec2 qs_a = i32vec2(cache_a[ib_a ].qs[iqs],
445+ cache_a[ib_a + 1 ].qs[iqs]);
409446#endif
410447
411- q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
448+ q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
449+ q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
412450 }
413451
414- return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1 );
// Each lane is finalized with its own A-side dm and the shared cache_b.ds.
452+ return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1 ),
453+ mul_q8_1(q_sum.y, cache_a[ib_a + 1 ].dm, cache_b.ds, 1 ));
415454}
416455#endif // MMQ_SHMEM
417456#endif
@@ -475,26 +514,33 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
475514 }
476515}
477516
// Diff view: '-' lines are the previous scalar implementation, '+' lines are its replacement.
// The replacement computes two results per call, one for cache_a[ib_a] (lane x) and one for
// cache_a[ib_a + 1] (lane y), both against the same cache_b block, and returns them together
// as an ACC_TYPE_VEC2. The eight words are processed in two groups of four; each group's
// integer sum is scaled by its own d_scales entry before being added to the float result.
// NOTE(review): the new code reads cache_a[ib_a + 1]; presumably callers now advance ib_a in
// steps of two and guarantee that entry exists - confirm at the call sites.
478- ACC_TYPE mmq_dot_product(const uint ib_a) {
479- float result = 0.0 ;
480- int32_t q_sum = 0 ;
517+ ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
518+ vec2 result = vec2 ( 0.0 ) ;
519+ i32vec2 q_sum = i32vec2( 0 ) ;
481520
482521 [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
483- const int32_t qs_a = cache_a[ib_a].qs[iqs];
// A words are used as stored; x comes from entry ib_a, y from entry ib_a + 1.
522+ const i32vec2 qs_a = i32vec2(cache_a[ib_a ].qs[iqs],
523+ cache_a[ib_a + 1 ].qs[iqs]);
484524
485- q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
525+ q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
526+ q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
486527 }
487- result += float (cache_a[ib_a].d_scales[0 ]) * float (q_sum);
488- q_sum = 0 ;
// First group done: apply d_scales[0] per lane, then reset the integer sums for the second group.
528+ result.x += float (cache_a[ib_a ].d_scales[0 ]) * float (q_sum.x);
529+ result.y += float (cache_a[ib_a + 1 ].d_scales[0 ]) * float (q_sum.y);
530+ q_sum = i32vec2(0 );
489531
490532 [[unroll]] for (uint iqs = 4 ; iqs < 8 ; iqs++ ) {
491- const int32_t qs_a = cache_a[ib_a].qs[iqs];
533+ const i32vec2 qs_a = i32vec2(cache_a[ib_a ].qs[iqs],
534+ cache_a[ib_a + 1 ].qs[iqs]);
492535
493- q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
536+ q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
537+ q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
494538 }
495- result += float (cache_a[ib_a].d_scales[1 ]) * float (q_sum);
// Second group uses d_scales[1].
539+ result.x += float (cache_a[ib_a ].d_scales[1 ]) * float (q_sum.x);
540+ result.y += float (cache_a[ib_a + 1 ].d_scales[1 ]) * float (q_sum.y);
496541
497- return ACC_TYPE(cache_b.ds.x * result);
// Both lanes are multiplied by the same B-side factor cache_b.ds.x.
542+ return ACC_TYPE_VEC2(cache_b.ds.x * result.x,
543+ cache_b.ds.x * result.y);
498544}
499545#endif // MMQ_SHMEM
500546#endif
0 commit comments