@@ -18,7 +18,7 @@ namespace hnswlib {
1818
1919// Favor using AVX if available.
2020 static float
21- InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
21+ InnerProductSIMD4ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
2222 float PORTABLE_ALIGN32 TmpRes[8 ];
2323 float *pVect1 = (float *) pVect1v;
2424 float *pVect2 = (float *) pVect2v;
@@ -64,10 +64,12 @@ namespace hnswlib {
6464 return 1 .0f - sum;
6565}
6666
67- #elif defined(USE_SSE)
67+ #endif
68+
69+ #if defined(USE_SSE)
6870
6971 static float
70- InnerProductSIMD4Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
72+ InnerProductSIMD4ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
7173 float PORTABLE_ALIGN32 TmpRes[8 ];
7274 float *pVect1 = (float *) pVect1v;
7375 float *pVect2 = (float *) pVect2v;
@@ -128,7 +130,7 @@ namespace hnswlib {
128130#if defined(USE_AVX512)
129131
130132 static float
131- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
133+ InnerProductSIMD16ExtAVX512 (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
132134 float PORTABLE_ALIGN64 TmpRes[16 ];
133135 float *pVect1 = (float *) pVect1v;
134136 float *pVect2 = (float *) pVect2v;
@@ -157,10 +159,12 @@ namespace hnswlib {
157159 return 1 .0f - sum;
158160 }
159161
160- #elif defined(USE_AVX)
162+ #endif
163+
164+ #if defined(USE_AVX)
161165
162166 static float
163- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
167+ InnerProductSIMD16ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
164168 float PORTABLE_ALIGN32 TmpRes[8 ];
165169 float *pVect1 = (float *) pVect1v;
166170 float *pVect2 = (float *) pVect2v;
@@ -195,10 +199,12 @@ namespace hnswlib {
195199 return 1 .0f - sum;
196200 }
197201
198- #elif defined(USE_SSE)
202+ #endif
203+
204+ #if defined(USE_SSE)
199205
200206 static float
201- InnerProductSIMD16Ext (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
207+ InnerProductSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
202208 float PORTABLE_ALIGN32 TmpRes[8 ];
203209 float *pVect1 = (float *) pVect1v;
204210 float *pVect2 = (float *) pVect2v;
@@ -245,6 +251,9 @@ namespace hnswlib {
245251#endif
246252
247253#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
254+ DISTFUNC<float > InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
255+ DISTFUNC<float > InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
256+
248257 static float
249258 InnerProductSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
250259 size_t qty = *((size_t *) qty_ptr);
@@ -283,6 +292,20 @@ namespace hnswlib {
283292 InnerProductSpace (size_t dim) {
284293 fstdistfunc_ = InnerProduct;
285294 #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
295+ #if defined(USE_AVX512)
296+ if (AVX512Capable ())
297+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
298+ else if (AVXCapable ())
299+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
300+ #elif defined(USE_AVX)
301+ if (AVXCapable ())
302+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
303+ #endif
304+ #if defined(USE_AVX)
305+ if (AVXCapable ())
306+ InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
307+ #endif
308+
286309 if (dim % 16 == 0 )
287310 fstdistfunc_ = InnerProductSIMD16Ext;
288311 else if (dim % 4 == 0 )
0 commit comments