Skip to content

Commit 17b29cd

Browse files
committed
No need for a scalar tail at all
1 parent 6c7ed1a commit 17b29cd

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 8 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -238,7 +238,6 @@ static int ipByteBit512(byte[] q, byte[] d) {
238238

239239
sectionLength = INT_SPECIES_256.length();
240240
if (q.length - i >= sectionLength) {
241-
// don't unroll this, we want to catch as many as we can before going scalar
242241
IntVector acc = IntVector.zero(INT_SPECIES_256);
243242
int limit = limit(q.length, sectionLength);
244243
for (; i < limit; i += sectionLength) {
@@ -252,11 +251,8 @@ static int ipByteBit512(byte[] q, byte[] d) {
252251
sum += acc.reduceLanes(VectorOperators.ADD);
253252
}
254253

255-
if (i < q.length) {
256-
// do the tail
257-
// default implementation uses length of data vector, not query vector
258-
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i / 8);
259-
}
254+
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
255+
assert i == q.length;
260256
return sum;
261257
}
262258

@@ -295,11 +291,8 @@ static int ipByteBit256(byte[] q, byte[] d) {
295291
+ acc3.reduceLanes(VectorOperators.ADD);
296292
}
297293

298-
if (i < q.length) {
299-
// do the tail
300-
// default implementation uses length of data vector, not query vector
301-
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i / 8);
302-
}
294+
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
295+
assert i == q.length;
303296
return sum;
304297
}
305298

@@ -341,7 +334,6 @@ static float ipFloatBit512(float[] q, byte[] d) {
341334

342335
sectionLength = FLOAT_SPECIES_256.length();
343336
if (q.length - i >= sectionLength) {
344-
// don't unroll this, we want to catch as many as we can before going scalar
345337
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
346338
int limit = limit(q.length, sectionLength);
347339
for (; i < limit; i += sectionLength) {
@@ -355,12 +347,8 @@ static float ipFloatBit512(float[] q, byte[] d) {
355347
sum += acc.reduceLanes(VectorOperators.ADD);
356348
}
357349

358-
if (i < q.length) {
359-
// do the tail
360-
// default implementation uses length of data vector, not query vector
361-
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i / 8);
362-
}
363-
350+
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
351+
assert i == q.length;
364352
return sum;
365353
}
366354

@@ -397,12 +385,8 @@ static float ipFloatBit256(float[] q, byte[] d) {
397385
+ acc3.reduceLanes(VectorOperators.ADD);
398386
}
399387

400-
if (i < q.length) {
401-
// do the tail
402-
// default implementation uses length of data vector, not query vector
403-
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i / 8);
404-
}
405-
388+
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
389+
assert i == q.length;
406390
return sum;
407391
}
408392

0 commit comments

Comments
 (0)