@@ -243,36 +243,70 @@ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict
     return r;
 }
 
-static inline HVX_Vector hvx_vec_rmpy_x8(HVX_Vector_x8 x, HVX_Vector_x8 y) {
-    HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
-    HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
-    HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
-    HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
-    HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
-    HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
-    HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
-    HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
-
-    HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
-    HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
-    HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
-    HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
-    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
-    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
-    r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
-    r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
-
-    p0 = Q6_W_vdeal_VVR(r1, r0, -4);
-    p1 = Q6_W_vdeal_VVR(r3, r2, -4);
-    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
-    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
-
-    p0 = Q6_W_vdeal_VVR(r1, r0, -4);
-    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+// Reduce-multiply two vectors of 1024 int8 elements each (32x q4/q8 blocks held in 8x HVX vectors).
+// Accumulate each block into a single int32 value.
+// Return a single HVX vector with 32x int32 accumulators.
+// This version is parameterized to support fewer than 1024 elements.
+// The if() checks are optimized out at compile time -- make sure to pass n as a compile-time constant.
+
+static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    HVX_Vector r0 = Q6_V_vsplat_R(0);
+    HVX_Vector r1 = Q6_V_vsplat_R(0);
+    HVX_Vector r2 = Q6_V_vsplat_R(0);
+    HVX_Vector r3 = Q6_V_vsplat_R(0);
+    HVX_Vector r4 = Q6_V_vsplat_R(0);
+    HVX_Vector r5 = Q6_V_vsplat_R(0);
+    HVX_Vector r6 = Q6_V_vsplat_R(0);
+    HVX_Vector r7 = Q6_V_vsplat_R(0);
+
+    HVX_VectorPair p3;
+    HVX_VectorPair p2;
+    HVX_VectorPair p1;
+    HVX_VectorPair p0;
+
+    if (n >= 128)  { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
+    if (n >= 256)  { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
+    if (n >= 384)  { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
+    if (n >= 512)  { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
+    if (n >= 640)  { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
+    if (n >= 768)  { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
+    if (n >= 896)  { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
+    if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }
+
+    if (n >= 128)  { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >= 384)  { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+    if (n >= 640)  { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
+    if (n >= 896)  { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }
+
+    if (n >= 128)  { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+    if (n >= 384)  { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+    if (n >= 640)  { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
+    if (n >= 896)  { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }
+
+    if (n >= 128)  { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >= 640)  { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+
+    if (n >= 128)  { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+    if (n >= 640)  { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+
+    if (n >= 128)  { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+    if (n >= 128)  { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
 
     return r0;
 }
 
+static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
+    return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
+// Handle the most common cases of tensors whose size is not a multiple of 1024.
+static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); }
+    if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); }
+    if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); }
+    return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
 static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % 32 == 0); // min sub-block size
     assert((unsigned long) vx % 128 == 0);
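
For reference, here is a minimal scalar model of what hvx_vec_rmpy_x8_n computes, assuming 128-byte HVX vectors and the 32-element block layout described in the comment above, and ignoring the exact lane ordering of the result (which the surrounding code accounts for when it applies the block scales). The function name and signature below are illustrative, not part of the source; it also assumes elements at indices >= n are zero-padded, as the packed block formats provide.

#include <stdint.h>

// Scalar sketch: dot-product each of the 32 blocks (32 int8 elements per block)
// into one int32 accumulator, zeroing blocks past n. out[] models the 32 int32
// lanes of the returned HVX vector.
static void ref_rmpy_x8_n(const int8_t x[1024], const int8_t y[1024],
                          unsigned int n, int32_t out[32]) {
    for (int blk = 0; blk < 32; blk++) {
        int32_t acc = 0;
        for (int k = 0; k < 32; k++) {
            unsigned int idx = 32 * blk + k;
            if (idx < n) {
                acc += (int32_t) x[idx] * (int32_t) y[idx];
            }
        }
        out[blk] = acc;
    }
}
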
@@ -309,7 +343,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -326,7 +360,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
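
The two hunks above show the caller-side pattern: full 1024-element chunks go through hvx_vec_rmpy_x8_full, and the single leftover chunk goes through hvx_vec_rmpy_x8_nloe, which rounds the remainder up to the next 256-element step so the compile-time pruning in hvx_vec_rmpy_x8_n still applies. A rough sketch of that dispatch (the driver shape and the names nchunks/nloe are assumptions based on the identifiers visible in the diff, not a verbatim copy of the source):

// Hypothetical driver shape: split n into full 1024-element chunks
// plus one tail chunk of nloe <= 1024 elements.
static void sketch_dot_driver(unsigned int n) {
    const unsigned int nchunks = n / 1024; // full chunks -> hvx_vec_rmpy_x8_full
    const unsigned int nloe    = n % 1024; // leftover elements -> hvx_vec_rmpy_x8_nloe

    unsigned int i = 0;
    for (; i < nchunks; i++) {
        // load vy_q / r0_q for chunk i, then:
        //   HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
    }
    if (nloe != 0) {
        // load the final, zero-padded chunk, then:
        //   HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
    }
}
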
@@ -393,8 +427,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -416,8 +450,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -482,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -499,7 +533,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -566,8 +600,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -589,8 +623,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -658,7 +692,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
@@ -690,7 +724,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
@@ -772,8 +806,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
@@ -813,8 +847,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
         HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r0_q, vy_q));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8(r1_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
         HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);