Skip to content

Commit f5da55d

Browse files
committed
Still need a tail for 256-bit impls
1 parent 17b29cd commit f5da55d

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,21 @@ static int ipByteBit256(byte[] q, byte[] d) {
291291
+ acc3.reduceLanes(VectorOperators.ADD);
292292
}
293293

294+
sectionLength = INT_SPECIES_256.length();
295+
if (q.length - i >= sectionLength) {
296+
IntVector acc = IntVector.zero(INT_SPECIES_256);
297+
int limit = limit(q.length, sectionLength);
298+
for (; i < limit; i += sectionLength) {
299+
var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
300+
301+
long maskBits = Integer.reverse(d[i / 8]) >> 24;
302+
var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits);
303+
304+
acc = acc.add(vals, mask);
305+
}
306+
sum += acc.reduceLanes(VectorOperators.ADD);
307+
}
308+
294309
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
295310
assert i == q.length;
296311
return sum;
@@ -385,6 +400,21 @@ static float ipFloatBit256(float[] q, byte[] d) {
385400
+ acc3.reduceLanes(VectorOperators.ADD);
386401
}
387402

403+
sectionLength = FLOAT_SPECIES_256.length();
404+
if (q.length - i >= sectionLength) {
405+
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
406+
int limit = limit(q.length, sectionLength);
407+
for (; i < limit; i += sectionLength) {
408+
var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
409+
410+
long maskBits = Integer.reverse(d[i / 8]) >> 24;
411+
var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
412+
413+
acc = acc.add(floats, mask);
414+
}
415+
sum += acc.reduceLanes(VectorOperators.ADD);
416+
}
417+
388418
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
389419
assert i == q.length;
390420
return sum;

0 commit comments

Comments
 (0)