Skip to content

Commit 646f1b8

Browse files
committed
Get length calculations correct
1 parent c983834 commit 646f1b8

File tree

2 files changed

+64
-28
lines changed

2 files changed

+64
-28
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,23 @@ public static long ipByteBin128(byte[] q, byte[] d) {
194194
VectorShape.forBitSize(INT_SPECIES_256.vectorBitSize() / Integer.BYTES)
195195
);
196196

197+
/**
 * Rounds {@code length} down to the nearest multiple of {@code sectionSize}
 * (for non-negative {@code length} and positive {@code sectionSize}).
 * The unrolled vector loops advance by {@code sectionSize} per iteration and use
 * this value as their exclusive upper bound, so each iteration has a whole
 * section of elements left to load.
 */
private static int limit(int length, int sectionSize) {
198+
return length - (length % sectionSize);
199+
}
200+
197201
static int ipByteBit512(byte[] q, byte[] d) {
198202
assert q.length == d.length * Byte.SIZE;
199203
int i = 0;
200204
int sum = 0;
201205

202-
if (q.length >= INT_SPECIES_512.length() * 4) {
206+
int sectionLength = INT_SPECIES_512.length() * 4;
207+
if (q.length >= sectionLength) {
203208
IntVector acc0 = IntVector.zero(INT_SPECIES_512);
204209
IntVector acc1 = IntVector.zero(INT_SPECIES_512);
205210
IntVector acc2 = IntVector.zero(INT_SPECIES_512);
206211
IntVector acc3 = IntVector.zero(INT_SPECIES_512);
207-
for (; i < INT_SPECIES_512.loopBound(q.length); i += INT_SPECIES_512.length() * 4) {
212+
int limit = limit(q.length, sectionLength);
213+
for (; i < limit; i += sectionLength) {
208214
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
209215
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0);
210216
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2)
@@ -227,12 +233,14 @@ static int ipByteBit512(byte[] q, byte[] d) {
227233
+ acc3.reduceLanes(VectorOperators.ADD);
228234
}
229235

230-
if (q.length - i >= INT_SPECIES_256.length() * 4) {
236+
sectionLength = INT_SPECIES_256.length() * 4;
237+
if (q.length - i >= sectionLength) {
231238
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
232239
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
233240
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
234241
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
235-
for (; i < INT_SPECIES_256.loopBound(q.length); i += INT_SPECIES_256.length() * 4) {
242+
int limit = limit(q.length, sectionLength);
243+
for (; i < limit; i += sectionLength) {
236244
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
237245
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
238246
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
@@ -257,7 +265,8 @@ static int ipByteBit512(byte[] q, byte[] d) {
257265

258266
if (i < q.length) {
259267
// do the tail
260-
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
268+
// the default implementation's start index is in units of the data vector d (one byte per 8 query elements), not the query vector q, hence i / 8
269+
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i / 8);
261270
}
262271
return sum;
263272
}
@@ -267,12 +276,14 @@ static int ipByteBit256(byte[] q, byte[] d) {
267276
int i = 0;
268277
int sum = 0;
269278

270-
if (q.length >= INT_SPECIES_256.length() * 4) {
279+
int sectionLength = INT_SPECIES_256.length() * 4;
280+
if (q.length >= sectionLength) {
271281
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
272282
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
273283
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
274284
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
275-
for (; i < INT_SPECIES_256.loopBound(q.length); i += INT_SPECIES_256.length() * 4) {
285+
int limit = limit(q.length, sectionLength);
286+
for (; i < limit; i += sectionLength) {
276287
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
277288
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
278289
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
@@ -297,7 +308,8 @@ static int ipByteBit256(byte[] q, byte[] d) {
297308

298309
if (i < q.length) {
299310
// do the tail
300-
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
311+
// the default implementation's start index is in units of the data vector d (one byte per 8 query elements), not the query vector q, hence i / 8
312+
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i / 8);
301313
}
302314
return sum;
303315
}
@@ -310,12 +322,14 @@ static float ipFloatBit512(float[] q, byte[] d) {
310322
int i = 0;
311323
float sum = 0;
312324

313-
if (q.length >= FLOAT_SPECIES_512.length() * 4) {
325+
int sectionLength = FLOAT_SPECIES_512.length() * 4;
326+
if (q.length >= sectionLength) {
314327
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512);
315328
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512);
316329
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512);
317330
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512);
318-
for (; i < FLOAT_SPECIES_512.loopBound(q.length); i += FLOAT_SPECIES_512.length() * 4) {
331+
int limit = limit(q.length, sectionLength);
332+
for (; i < limit; i += sectionLength) {
319333
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
320334
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length());
321335
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2);
@@ -336,12 +350,14 @@ static float ipFloatBit512(float[] q, byte[] d) {
336350
+ acc3.reduceLanes(VectorOperators.ADD);
337351
}
338352

339-
if (q.length - i >= FLOAT_SPECIES_256.length() * 4) {
353+
sectionLength = FLOAT_SPECIES_256.length() * 4;
354+
if (q.length - i >= sectionLength) {
340355
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
341356
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
342357
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
343358
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
344-
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length() * 4) {
359+
int limit = limit(q.length, sectionLength);
360+
for (; i < limit; i += sectionLength) {
345361
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
346362
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
347363
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
@@ -364,7 +380,8 @@ static float ipFloatBit512(float[] q, byte[] d) {
364380

365381
if (i < q.length) {
366382
// do the tail
367-
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);
383+
// the default implementation's start index is in units of the data vector d (one byte per 8 query elements), not the query vector q, hence i / 8
384+
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i / 8);
368385
}
369386

370387
return sum;
@@ -375,12 +392,14 @@ static float ipFloatBit256(float[] q, byte[] d) {
375392
int i = 0;
376393
float sum = 0;
377394

378-
if (q.length >= FLOAT_SPECIES_256.length() * 4) {
395+
int sectionLength = FLOAT_SPECIES_256.length() * 4;
396+
if (q.length >= sectionLength) {
379397
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
380398
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
381399
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
382400
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
383-
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length() * 4) {
401+
int limit = limit(q.length, sectionLength);
402+
for (; i < limit; i += sectionLength) {
384403
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
385404
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
386405
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
@@ -403,7 +422,8 @@ static float ipFloatBit256(float[] q, byte[] d) {
403422

404423
if (i < q.length) {
405424
// do the tail
406-
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);
425+
// the default implementation's start index is in units of the data vector d (one byte per 8 query elements), not the query vector q, hence i / 8
426+
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i / 8);
407427
}
408428

409429
return sum;

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,41 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
2222
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
2323

2424
public void testIpByteBit() {
25-
byte[] q = new byte[16];
26-
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
25+
byte[] d = new byte[random().nextInt(128)];
26+
byte[] q = new byte[d.length * 8];
27+
random().nextBytes(d);
2728
random().nextBytes(q);
28-
int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
29-
assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
30-
assertEquals(expected, defaultedProvider.getVectorUtilSupport().ipByteBit(q, d));
31-
assertEquals(expected, defOrPanamaProvider.getVectorUtilSupport().ipByteBit(q, d));
29+
30+
int sum = 0;
31+
for (int i = 0; i < q.length; i++) {
32+
if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) {
33+
sum += q[i];
34+
}
35+
}
36+
37+
assertEquals(sum, ESVectorUtil.ipByteBit(q, d));
38+
assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipByteBit(q, d));
39+
assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipByteBit(q, d));
3240
}
3341

3442
public void testIpFloatBit() {
35-
float[] q = new float[16];
36-
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
43+
byte[] d = new byte[random().nextInt(128)];
44+
float[] q = new float[d.length * 8];
45+
random().nextBytes(d);
46+
47+
float sum = 0;
3748
for (int i = 0; i < q.length; i++) {
3849
q[i] = random().nextFloat();
50+
if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) {
51+
sum += q[i];
52+
}
3953
}
40-
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
41-
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
42-
assertEquals(expected, defaultedProvider.getVectorUtilSupport().ipFloatBit(q, d), 1e-6);
43-
assertEquals(expected, defOrPanamaProvider.getVectorUtilSupport().ipFloatBit(q, d), 1e-6);
54+
55+
double delta = 1e-5 * q.length;
56+
57+
assertEquals(sum, ESVectorUtil.ipFloatBit(q, d), delta);
58+
assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipFloatBit(q, d), delta);
59+
assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipFloatBit(q, d), delta);
4460
}
4561

4662
public void testIpFloatByte() {

0 commit comments

Comments
 (0)