Skip to content

Commit a3ef160

Browse files
authored
Merge pull request #211 from 2ooom/master
Perf improvement for dimensions that are not multiples of 4 or 16
2 parents a97ec89 + 30ac4c5 commit a3ef160

File tree

2 files changed

+87
-32
lines changed

2 files changed

+87
-32
lines changed

hnswlib/space_ip.h

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,36 @@ namespace hnswlib {
211211

212212
#endif
213213

214+
#if defined(USE_SSE) || defined(USE_AVX)
215+
static float
216+
InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
217+
size_t qty = *((size_t *) qty_ptr);
218+
size_t qty16 = qty >> 4 << 4;
219+
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
220+
float *pVect1 = (float *) pVect1v + qty16;
221+
float *pVect2 = (float *) pVect2v + qty16;
222+
223+
size_t qty_left = qty - qty16;
224+
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
225+
return res + res_tail - 1.0f;
226+
}
227+
228+
static float
229+
InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
230+
size_t qty = *((size_t *) qty_ptr);
231+
size_t qty4 = qty >> 2 << 2;
232+
233+
float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
234+
size_t qty_left = qty - qty4;
235+
236+
float *pVect1 = (float *) pVect1v + qty4;
237+
float *pVect2 = (float *) pVect2v + qty4;
238+
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
239+
240+
return res + res_tail - 1.0f;
241+
}
242+
#endif
243+
214244
class InnerProductSpace : public SpaceInterface<float> {
215245

216246
DISTFUNC<float> fstdistfunc_;
@@ -220,11 +250,15 @@ namespace hnswlib {
220250
InnerProductSpace(size_t dim) {
221251
fstdistfunc_ = InnerProduct;
222252
#if defined(USE_AVX) || defined(USE_SSE)
223-
if (dim % 4 == 0)
224-
fstdistfunc_ = InnerProductSIMD4Ext;
225253
if (dim % 16 == 0)
226254
fstdistfunc_ = InnerProductSIMD16Ext;
227-
#endif
255+
else if (dim % 4 == 0)
256+
fstdistfunc_ = InnerProductSIMD4Ext;
257+
else if (dim > 16)
258+
fstdistfunc_ = InnerProductSIMD16ExtResiduals;
259+
else if (dim > 4)
260+
fstdistfunc_ = InnerProductSIMD4ExtResiduals;
261+
#endif
228262
dim_ = dim;
229263
data_size_ = dim * sizeof(float);
230264
}

hnswlib/space_l2.h

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
namespace hnswlib {
55

66
/**
 * Scalar squared Euclidean (L2) distance between two float vectors.
 *
 * @param pVect1v  first vector, read as an array of float
 * @param pVect2v  second vector, read as an array of float
 * @param qty_ptr  points to a size_t holding the number of components
 * @return sum over all components of (a[i] - b[i])^2; 0 when the count is 0
 */
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    // The inputs are only read, so keep them const instead of casting it away.
    const float *pVect1 = static_cast<const float *>(pVect1v);
    const float *pVect2 = static_cast<const float *>(pVect2v);
    const size_t qty = *static_cast<const size_t *>(qty_ptr);

    float res = 0;
    // size_t index matches the type of qty, so the comparison never mixes widths.
    for (size_t i = 0; i < qty; i++) {
        const float t = pVect1[i] - pVect2[i];
        res += t * t;
    }
    return res;
}
1821

1922
#if defined(USE_AVX)
@@ -49,10 +52,8 @@ namespace hnswlib {
4952
}
5053

5154
_mm256_store_ps(TmpRes, sum);
52-
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
53-
54-
return (res);
55-
}
55+
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
56+
}
5657

5758
#elif defined(USE_SSE)
5859

@@ -62,12 +63,9 @@ namespace hnswlib {
6263
float *pVect2 = (float *) pVect2v;
6364
size_t qty = *((size_t *) qty_ptr);
6465
float PORTABLE_ALIGN32 TmpRes[8];
65-
// size_t qty4 = qty >> 2;
6666
size_t qty16 = qty >> 4;
6767

6868
const float *pEnd1 = pVect1 + (qty16 << 4);
69-
// const float* pEnd2 = pVect1 + (qty4 << 2);
70-
// const float* pEnd3 = pVect1 + qty;
7169

7270
__m128 diff, v1, v2;
7371
__m128 sum = _mm_set1_ps(0);
@@ -102,10 +100,24 @@ namespace hnswlib {
102100
diff = _mm_sub_ps(v1, v2);
103101
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
104102
}
103+
105104
_mm_store_ps(TmpRes, sum);
106-
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
105+
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
106+
}
107+
#endif
107108

108-
return (res);
109+
#if defined(USE_SSE) || defined(USE_AVX)
110+
static float
111+
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
112+
size_t qty = *((size_t *) qty_ptr);
113+
size_t qty16 = qty >> 4 << 4;
114+
float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
115+
float *pVect1 = (float *) pVect1v + qty16;
116+
float *pVect2 = (float *) pVect2v + qty16;
117+
118+
size_t qty_left = qty - qty16;
119+
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
120+
return (res + res_tail);
109121
}
110122
#endif
111123

@@ -119,10 +131,9 @@ namespace hnswlib {
119131
size_t qty = *((size_t *) qty_ptr);
120132

121133

122-
// size_t qty4 = qty >> 2;
123-
size_t qty16 = qty >> 2;
134+
size_t qty4 = qty >> 2;
124135

125-
const float *pEnd1 = pVect1 + (qty16 << 2);
136+
const float *pEnd1 = pVect1 + (qty4 << 2);
126137

127138
__m128 diff, v1, v2;
128139
__m128 sum = _mm_set1_ps(0);
@@ -136,9 +147,22 @@ namespace hnswlib {
136147
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
137148
}
138149
_mm_store_ps(TmpRes, sum);
139-
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
150+
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
151+
}
140152

141-
return (res);
153+
static float
154+
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
155+
size_t qty = *((size_t *) qty_ptr);
156+
size_t qty4 = qty >> 2 << 2;
157+
158+
float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
159+
size_t qty_left = qty - qty4;
160+
161+
float *pVect1 = (float *) pVect1v + qty4;
162+
float *pVect2 = (float *) pVect2v + qty4;
163+
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
164+
165+
return (res + res_tail);
142166
}
143167
#endif
144168

@@ -151,13 +175,14 @@ namespace hnswlib {
151175
L2Space(size_t dim) {
152176
fstdistfunc_ = L2Sqr;
153177
#if defined(USE_SSE) || defined(USE_AVX)
154-
if (dim % 4 == 0)
155-
fstdistfunc_ = L2SqrSIMD4Ext;
156178
if (dim % 16 == 0)
157179
fstdistfunc_ = L2SqrSIMD16Ext;
158-
/*else{
159-
throw runtime_error("Data type not supported!");
160-
}*/
180+
else if (dim % 4 == 0)
181+
fstdistfunc_ = L2SqrSIMD4Ext;
182+
else if (dim > 16)
183+
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
184+
else if (dim > 4)
185+
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
161186
#endif
162187
dim_ = dim;
163188
data_size_ = dim * sizeof(float);
@@ -185,10 +210,6 @@ namespace hnswlib {
185210
int res = 0;
186211
unsigned char *a = (unsigned char *) pVect1;
187212
unsigned char *b = (unsigned char *) pVect2;
188-
/*for (int i = 0; i < qty; i++) {
189-
int t = int((a)[i]) - int((b)[i]);
190-
res += t*t;
191-
}*/
192213

193214
qty = qty >> 2;
194215
for (size_t i = 0; i < qty; i++) {
@@ -241,4 +262,4 @@ namespace hnswlib {
241262
};
242263

243264

244-
}
265+
}

0 commit comments

Comments
 (0)