Skip to content

Commit edf7ce3

Browse files
committed
using dot
1 parent ec0b2b4 commit edf7ce3

File tree

2 files changed

+18
-129
lines changed

2 files changed

+18
-129
lines changed

src/VecSim/spaces/L2/L2_SVE_INT8.h

Lines changed: 5 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,39 +17,9 @@ inline void L2SquareStep(const int8_t *&pVect1, const int8_t *&pVect2, size_t &o
1717
svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset); // Load int8 vectors from pVect1
1818
svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset); // Load int8 vectors from pVect2
1919

20-
svint16_t v1_16_l = svunpklo_s16(v1_i8);
21-
svint16_t v1_16_h = svunpkhi_s16(v1_i8);
22-
svint16_t v2_16_l = svunpklo_s16(v2_i8);
23-
svint16_t v2_16_h = svunpkhi_s16(v2_i8);
24-
25-
// Calculate difference and square for low part
26-
svint16_t diff_l = svsub_s16_x(pg, v1_16_l, v2_16_l);
27-
28-
// Unpacking to 32 bits is necessary for the multiplication
29-
// The multiplication of two 16 bits numbers can overflow
30-
// Maximal value of int8 - int8 is 255 (127 - (-128))
31-
// 255^2 = 65025 while int16 can hold upto 32767
32-
33-
svint32_t diff32_l_l = svunpklo_s32(diff_l);
34-
svint32_t diff32_l_h = svunpkhi_s32(diff_l);
35-
36-
// Result register is the same as the accumulator for better performance
37-
svint32_t sq_l = svmul_s32_x(pg, diff32_l_l, diff32_l_l);
38-
sq_l = svmla_s32_x(pg, sq_l, diff32_l_h, diff32_l_h);
39-
40-
svint16_t diff_h = svsub_s16_x(pg, v1_16_h, v2_16_h);
41-
42-
svint32_t diff32_h_l = svunpklo_s32(diff_h);
43-
svint32_t diff32_h_h = svunpkhi_s32(diff_h);
44-
45-
// Result register is the same as the accumulator for better performance
46-
svint32_t sq_h = svmul_s32_x(pg, diff32_h_l, diff32_h_l);
47-
sq_h = svmla_s32_x(pg, sq_h, diff32_h_h, diff32_h_h);
48-
49-
// Accumulate
50-
sum = svadd_s32_x(pg, sum, sq_l);
51-
sum = svadd_s32_x(pg, sum, sq_h);
20+
svint8_t abs_diff = svabd_s8_x(pg, v1_i8, v2_i8);
5221

22+
sum = svdot_s32(sum, abs_diff, abs_diff);
5323
offset += chunk; // Move to the next set of int8 elements
5424
}
5525

@@ -103,39 +73,14 @@ float INT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimens
10373
Because inactive lanes are set to 0 by the load */
10474

10575
svbool_t pg = svwhilelt_b8_u64(offset, dimension);
106-
svbool_t pg16 = svwhilelt_b16(offset, dimension);
107-
svbool_t pg32 = svwhilelt_b32(offset, dimension);
10876

10977
svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset); // Load int8 vectors from pVect1
11078
svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset); // Load int8 vectors from pVect2
11179

112-
svint16_t v1_16_l = svunpklo_s16(v1_i8);
113-
svint16_t v1_16_h = svunpkhi_s16(v1_i8);
114-
svint16_t v2_16_l = svunpklo_s16(v2_i8);
115-
svint16_t v2_16_h = svunpkhi_s16(v2_i8);
116-
117-
// Calculate difference and square for low part
118-
svint16_t diff_l = svsub_s16_x(pg16, v1_16_l, v2_16_l);
119-
120-
svint32_t diff32_l_l = svunpklo_s32(diff_l);
121-
svint32_t diff32_l_h = svunpkhi_s32(diff_l);
122-
123-
// Result register is the same as the accumulator for better performance
124-
svint32_t sq_l = svmul_s32_x(pg32, diff32_l_l, diff32_l_l);
125-
sq_l = svmla_s32_x(pg32, sq_l, diff32_l_h, diff32_l_h);
126-
127-
svint16_t diff_h = svsub_s16_x(pg16, v1_16_h, v2_16_h);
128-
129-
svint32_t diff32_h_l = svunpklo_s32(diff_h);
130-
svint32_t diff32_h_h = svunpkhi_s32(diff_h);
131-
132-
// Result register is the same as the accumulator for better performance
133-
svint32_t sq_h = svmul_s32_x(pg32, diff32_h_l, diff32_h_l);
134-
sq_h = svmla_s32_x(pg32, sq_h, diff32_h_h, diff32_h_h);
80+
svint8_t abs_diff = svabd_s8_x(pg, v1_i8, v2_i8);
13581

136-
// Accumulate
137-
sum3 = svadd_s32_x(pg32, sum3, sq_l);
138-
sum3 = svadd_s32_x(pg32, sum3, sq_h);
82+
// Summing without masking by pg is safe here because svld1 sets inactive lanes to 0
83+
sum3 = svdot_s32(sum3, abs_diff, abs_diff);
13984
}
14085

14186
sum0 = svadd_s32_x(all, sum0, sum1);

src/VecSim/spaces/L2/L2_SVE_UINT8.h

Lines changed: 13 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9,48 +9,17 @@
99

1010
// Aligned step using svptrue_b8()
1111
inline void L2SquareStep(const uint8_t *&pVect1, const uint8_t *&pVect2, size_t &offset,
12-
svint32_t &sum, const size_t chunk) {
12+
svuint32_t &sum, const size_t chunk) {
1313
svbool_t pg = svptrue_b8();
1414
// Note: Because all the bits are 1, the extension to 16 and 32 bits does not make a difference
1515
// Otherwise, pg should be recalculated for 16 and 32 operations
1616

1717
svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors from pVect1
1818
svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors from pVect2
1919

20-
// Unpack to 16 bits and reinterpret as signed is necessary for the subtraction
21-
// The subtraction of two 8 bits numbers can overflow
20+
svuint8_t abs_diff = svabd_u8_x(pg, v1_ui8, v2_ui8);
2221

23-
svint16_t v1_16_l = svreinterpret_s16(svunpklo_u16(v1_ui8));
24-
svint16_t v1_16_h = svreinterpret_s16(svunpkhi_u16(v1_ui8));
25-
svint16_t v2_16_l = svreinterpret_s16(svunpklo_u16(v2_ui8));
26-
svint16_t v2_16_h = svreinterpret_s16(svunpkhi_u16(v2_ui8));
27-
28-
// Calculate difference and square for low part
29-
svint16_t diff_l = svsub_s16_x(pg, v1_16_l, v2_16_l);
30-
31-
// Unpacking to 32 bits is necessary for the multiplication
32-
// The multiplication of two 16 bits numbers can overflow
33-
// Maximal value of uint8 - uint8 is 255 (255 - 0)
34-
// 255^2 = 65025 while int16 can hold upto 32767
35-
36-
svint32_t diff32_l_l = svunpklo_s32(diff_l);
37-
svint32_t diff32_l_h = svunpkhi_s32(diff_l);
38-
39-
// Result register is the same as the accumulator for better performance
40-
svint32_t sq_l = svmul_s32_x(pg, diff32_l_l, diff32_l_l);
41-
sq_l = svmla_s32_x(pg, sq_l, diff32_l_h, diff32_l_h);
42-
43-
svint16_t diff_h = svsub_s16_x(pg, v1_16_h, v2_16_h);
44-
45-
svint32_t diff32_h_l = svunpklo_s32(diff_h);
46-
svint32_t diff32_h_h = svunpkhi_s32(diff_h);
47-
48-
// Result register is the same as the accumulator for better performance
49-
svint32_t sq_h = svmul_s32_x(pg, diff32_h_l, diff32_h_l);
50-
sq_h = svmla_s32_x(pg, sq_h, diff32_h_h, diff32_h_h);
51-
52-
sum = svadd_s32_x(pg, sum, sq_l);
53-
sum = svadd_s32_x(pg, sum, sq_h);
22+
sum = svdot_u32(sum, abs_diff, abs_diff);
5423

5524
offset += chunk; // Move to the next set of uint8 elements
5625
}
@@ -72,10 +41,10 @@ float UINT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimen
7241
// We can safely assume that the dimension is smaller than that
7342
// So using int32_t is safe
7443

75-
svint32_t sum0 = svdup_s32(0);
76-
svint32_t sum1 = svdup_s32(0);
77-
svint32_t sum2 = svdup_s32(0);
78-
svint32_t sum3 = svdup_s32(0);
44+
svuint32_t sum0 = svdup_u32(0);
45+
svuint32_t sum1 = svdup_u32(0);
46+
svuint32_t sum2 = svdup_u32(0);
47+
svuint32_t sum3 = svdup_u32(0);
7948

8049
size_t offset = 0;
8150
size_t num_main_blocks = dimension / chunk_size;
@@ -105,38 +74,13 @@ float UINT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimen
10574
svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors from pVect1
10675
svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors from pVect2
10776

108-
svbool_t pg32 = svwhilelt_b32(offset, dimension);
109-
110-
svint16_t v1_16_l = svreinterpret_s16(svunpklo_u16(v1_ui8));
111-
svint16_t v1_16_h = svreinterpret_s16(svunpkhi_u16(v1_ui8));
112-
svint16_t v2_16_l = svreinterpret_s16(svunpklo_u16(v2_ui8));
113-
svint16_t v2_16_h = svreinterpret_s16(svunpkhi_u16(v2_ui8));
114-
115-
// Calculate difference and square for low part
116-
svint16_t diff_l = svsub_s16_x(svwhilelt_b16(offset, dimension), v1_16_l, v2_16_l);
117-
118-
svint32_t diff32_l_l = svunpklo_s32(diff_l);
119-
svint32_t diff32_l_h = svunpkhi_s32(diff_l);
120-
121-
// Result register is the same as the accumulator for better performance
122-
svint32_t sq_l = svmul_s32_x(pg32, diff32_l_l, diff32_l_l);
123-
sq_l = svmla_s32_x(pg32, sq_l, diff32_l_h, diff32_l_h);
124-
125-
svint16_t diff_h = svsub_s16_x(pg32, v1_16_h, v2_16_h);
126-
127-
svint32_t diff32_h_l = svunpklo_s32(diff_h);
128-
svint32_t diff32_h_h = svunpkhi_s32(diff_h);
129-
130-
// Result register is the same as the accumulator for better performance
131-
svint32_t sq_h = svmul_s32_x(pg32, diff32_h_l, diff32_h_l);
132-
sq_h = svmla_s32_x(pg32, sq_h, diff32_h_h, diff32_h_h);
77+
svuint8_t abs_diff = svabd_u8_x(pg, v1_ui8, v2_ui8);
13378

134-
sum3 = svadd_s32_m(pg32, sum3, sq_l);
135-
sum3 = svadd_s32_m(pg32, sum3, sq_h);
79+
sum3 = svdot_u32(sum3, abs_diff, abs_diff);
13680
}
13781

138-
sum0 = svadd_s32_x(all, sum0, sum1);
139-
sum2 = svadd_s32_x(all, sum2, sum3);
140-
svint32_t sum_all = svadd_s32_x(all, sum0, sum2);
141-
return svaddv_s32(svptrue_b32(), sum_all);
82+
sum0 = svadd_u32_x(all, sum0, sum1);
83+
sum2 = svadd_u32_x(all, sum2, sum3);
84+
svuint32_t sum_all = svadd_u32_x(all, sum0, sum2);
85+
return svaddv_u32(svptrue_b32(), sum_all);
14286
}

0 commit comments

Comments
 (0)