44namespace hnswlib {
55
66 static float
7- L2Sqr (const void *pVect1, const void *pVect2, const void *qty_ptr) {
8- // return *((float *)pVect2);
7+ L2Sqr (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
8+ float *pVect1 = (float *) pVect1v;
9+ float *pVect2 = (float *) pVect2v;
910 size_t qty = *((size_t *) qty_ptr);
11+
1012 float res = 0 ;
11- for (unsigned i = 0 ; i < qty; i++) {
12- float t = ((float *) pVect1)[i] - ((float *) pVect2)[i];
13+ for (size_t i = 0 ; i < qty; i++) {
14+ float t = *pVect1 - *pVect2;
15+ pVect1++;
16+ pVect2++;
1317 res += t * t;
1418 }
1519 return (res);
16-
1720 }
1821
1922#if defined(USE_AVX)
@@ -49,10 +52,8 @@ namespace hnswlib {
4952 }
5053
5154 _mm256_store_ps (TmpRes, sum);
52- float res = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
53-
54- return (res);
55- }
55+ return TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
56+ }
5657
5758#elif defined(USE_SSE)
5859
@@ -62,12 +63,9 @@ namespace hnswlib {
6263 float *pVect2 = (float *) pVect2v;
6364 size_t qty = *((size_t *) qty_ptr);
6465 float PORTABLE_ALIGN32 TmpRes[8 ];
65- // size_t qty4 = qty >> 2;
6666 size_t qty16 = qty >> 4 ;
6767
6868 const float *pEnd1 = pVect1 + (qty16 << 4 );
69- // const float* pEnd2 = pVect1 + (qty4 << 2);
70- // const float* pEnd3 = pVect1 + qty;
7169
7270 __m128 diff, v1, v2;
7371 __m128 sum = _mm_set1_ps (0 );
@@ -102,10 +100,24 @@ namespace hnswlib {
102100 diff = _mm_sub_ps (v1, v2);
103101 sum = _mm_add_ps (sum, _mm_mul_ps (diff, diff));
104102 }
103+
105104 _mm_store_ps (TmpRes, sum);
106- float res = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
105+ return TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
106+ }
107+ #endif
107108
108- return (res);
109+ #if defined(USE_SSE) || defined(USE_AVX)
110+ static float
111+ L2SqrSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
112+ size_t qty = *((size_t *) qty_ptr);
113+ size_t qty16 = qty >> 4 << 4 ;
114+ float res = L2SqrSIMD16Ext (pVect1v, pVect2v, &qty16);
115+ float *pVect1 = (float *) pVect1v + qty16;
116+ float *pVect2 = (float *) pVect2v + qty16;
117+
118+ size_t qty_left = qty - qty16;
119+ float res_tail = L2Sqr (pVect1, pVect2, &qty_left);
120+ return (res + res_tail);
109121 }
110122#endif
111123
@@ -119,10 +131,9 @@ namespace hnswlib {
119131 size_t qty = *((size_t *) qty_ptr);
120132
121133
122- // size_t qty4 = qty >> 2;
123- size_t qty16 = qty >> 2 ;
134+ size_t qty4 = qty >> 2 ;
124135
125- const float *pEnd1 = pVect1 + (qty16 << 2 );
136+ const float *pEnd1 = pVect1 + (qty4 << 2 );
126137
127138 __m128 diff, v1, v2;
128139 __m128 sum = _mm_set1_ps (0 );
@@ -136,9 +147,22 @@ namespace hnswlib {
136147 sum = _mm_add_ps (sum, _mm_mul_ps (diff, diff));
137148 }
138149 _mm_store_ps (TmpRes, sum);
139- float res = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
150+ return TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
151+ }
140152
141- return (res);
153+ static float
154+ L2SqrSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
155+ size_t qty = *((size_t *) qty_ptr);
156+ size_t qty4 = qty >> 2 << 2 ;
157+
158+ float res = L2SqrSIMD4Ext (pVect1v, pVect2v, &qty4);
159+ size_t qty_left = qty - qty4;
160+
161+ float *pVect1 = (float *) pVect1v + qty4;
162+ float *pVect2 = (float *) pVect2v + qty4;
163+ float res_tail = L2Sqr (pVect1, pVect2, &qty_left);
164+
165+ return (res + res_tail);
142166 }
143167#endif
144168
@@ -151,13 +175,14 @@ namespace hnswlib {
151175 L2Space (size_t dim) {
152176 fstdistfunc_ = L2Sqr;
153177 #if defined(USE_SSE) || defined(USE_AVX)
154- if (dim % 4 == 0 )
155- fstdistfunc_ = L2SqrSIMD4Ext;
156178 if (dim % 16 == 0 )
157179 fstdistfunc_ = L2SqrSIMD16Ext;
158- /* else{
159- throw runtime_error("Data type not supported!");
160- }*/
180+ else if (dim % 4 == 0 )
181+ fstdistfunc_ = L2SqrSIMD4Ext;
182+ else if (dim > 16 )
183+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
184+ else if (dim > 4 )
185+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
161186 #endif
162187 dim_ = dim;
163188 data_size_ = dim * sizeof (float );
@@ -185,10 +210,6 @@ namespace hnswlib {
185210 int res = 0 ;
186211 unsigned char *a = (unsigned char *) pVect1;
187212 unsigned char *b = (unsigned char *) pVect2;
188- /* for (int i = 0; i < qty; i++) {
189- int t = int((a)[i]) - int((b)[i]);
190- res += t*t;
191- }*/
192213
193214 qty = qty >> 2 ;
194215 for (size_t i = 0 ; i < qty; i++) {
@@ -241,4 +262,4 @@ namespace hnswlib {
241262 };
242263
243264
244- }
265+ }
0 commit comments