99
// Aligned step using svptrue_b8()
// Accumulates the squared-L2 contribution of one full SVE vector of uint8
// elements read from pVect1/pVect2 at `offset` into `sum`, then advances
// `offset` by `chunk` (the number of uint8 lanes per SVE vector).
inline void L2SquareStep(const uint8_t *&pVect1, const uint8_t *&pVect2, size_t &offset,
                         svuint32_t &sum, const size_t chunk) {
    svbool_t pg = svptrue_b8();
    // Note: Because all the bits are 1, the extension to 16 and 32 bits does not make a difference
    // Otherwise, pg would have to be recalculated for the 16- and 32-bit operations

    svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors from pVect1
    svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors from pVect2

    // (a - b)^2 == |a - b|^2, and |a - b| of two uint8 values fits in uint8
    // (max 255), so the unsigned absolute difference replaces the old
    // widen-subtract-square sequence without any intermediate overflow.
    svuint8_t abs_diff = svabd_u8_x(pg, v1_ui8, v2_ui8);

    // UDOT: each uint32 lane accumulates four u8*u8 products (max 4 * 255^2,
    // well inside uint32), i.e. sum += dot(abs_diff, abs_diff).
    sum = svdot_u32(sum, abs_diff, abs_diff);

    offset += chunk; // Move to the next set of uint8 elements
}
@@ -72,10 +41,10 @@ float UINT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimen
7241 // We can safely assume that the dimension is smaller than that
7342 // So using int32_t is safe
7443
75- svint32_t sum0 = svdup_s32 (0 );
76- svint32_t sum1 = svdup_s32 (0 );
77- svint32_t sum2 = svdup_s32 (0 );
78- svint32_t sum3 = svdup_s32 (0 );
44+ svuint32_t sum0 = svdup_u32 (0 );
45+ svuint32_t sum1 = svdup_u32 (0 );
46+ svuint32_t sum2 = svdup_u32 (0 );
47+ svuint32_t sum3 = svdup_u32 (0 );
7948
8049 size_t offset = 0 ;
8150 size_t num_main_blocks = dimension / chunk_size;
@@ -105,38 +74,13 @@ float UINT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimen
10574 svuint8_t v1_ui8 = svld1_u8 (pg, pVect1 + offset); // Load uint8 vectors from pVect1
10675 svuint8_t v2_ui8 = svld1_u8 (pg, pVect2 + offset); // Load uint8 vectors from pVect2
10776
108- svbool_t pg32 = svwhilelt_b32 (offset, dimension);
109-
110- svint16_t v1_16_l = svreinterpret_s16 (svunpklo_u16 (v1_ui8));
111- svint16_t v1_16_h = svreinterpret_s16 (svunpkhi_u16 (v1_ui8));
112- svint16_t v2_16_l = svreinterpret_s16 (svunpklo_u16 (v2_ui8));
113- svint16_t v2_16_h = svreinterpret_s16 (svunpkhi_u16 (v2_ui8));
114-
115- // Calculate difference and square for low part
116- svint16_t diff_l = svsub_s16_x (svwhilelt_b16 (offset, dimension), v1_16_l, v2_16_l);
117-
118- svint32_t diff32_l_l = svunpklo_s32 (diff_l);
119- svint32_t diff32_l_h = svunpkhi_s32 (diff_l);
120-
121- // Result register is the same as the accumulator for better performance
122- svint32_t sq_l = svmul_s32_x (pg32, diff32_l_l, diff32_l_l);
123- sq_l = svmla_s32_x (pg32, sq_l, diff32_l_h, diff32_l_h);
124-
125- svint16_t diff_h = svsub_s16_x (pg32, v1_16_h, v2_16_h);
126-
127- svint32_t diff32_h_l = svunpklo_s32 (diff_h);
128- svint32_t diff32_h_h = svunpkhi_s32 (diff_h);
129-
130- // Result register is the same as the accumulator for better performance
131- svint32_t sq_h = svmul_s32_x (pg32, diff32_h_l, diff32_h_l);
132- sq_h = svmla_s32_x (pg32, sq_h, diff32_h_h, diff32_h_h);
77+ svuint8_t abs_diff = svabd_u8_x (pg, v1_ui8, v2_ui8);
13378
134- sum3 = svadd_s32_m (pg32, sum3, sq_l);
135- sum3 = svadd_s32_m (pg32, sum3, sq_h);
79+ sum3 = svdot_u32 (sum3, abs_diff, abs_diff);
13680 }
13781
138- sum0 = svadd_s32_x (all, sum0, sum1);
139- sum2 = svadd_s32_x (all, sum2, sum3);
140- svint32_t sum_all = svadd_s32_x (all, sum0, sum2);
141- return svaddv_s32 (svptrue_b32 (), sum_all);
82+ sum0 = svadd_u32_x (all, sum0, sum1);
83+ sum2 = svadd_u32_x (all, sum2, sum3);
84+ svuint32_t sum_all = svadd_u32_x (all, sum0, sum2);
85+ return svaddv_u32 (svptrue_b32 (), sum_all);
14286}
0 commit comments