diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index c98532d8dd8f5..1ce0e27700744 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -52,7 +52,7 @@ public class OSQScorerBenchmark { LogConfigurator.configureESLogging(); // native access requires logging to be initialized } - @Param({ "1024" }) + @Param({ "384", "782", "1024" }) int dims; int length; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 4be6ede34530a..4b899cf987600 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -18,6 +18,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; @@ -118,8 +119,22 @@ private long quantizeScore256(byte[] q) throws IOException { subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); } - // tail as bytes + // process scalar tail in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } for (; i < length; i++) { int dValue = in.readByte() & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); @@ -158,14 +173,28 @@ private long quantizeScore128(byte[] q) throws IOException { subRet1 += sum1.reduceLanes(VectorOperators.ADD); subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); - // tail as bytes + // process scalar tail in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } for (; i < length; i++) { int dValue = in.readByte() & 0xFF; - subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); - subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); - subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); - subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); } return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } @@ -215,14 +244,28 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO subRet1 += sum1.reduceLanes(VectorOperators.ADD); subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); - // tail as bytes + // process scalar tail in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } for (; i < length; i++) { int dValue = in.readByte() & 0xFF; - subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); - subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); - subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); - subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); } scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } @@ -281,8 +324,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO subRet2 += sum2.reduceLanes(VectorOperators.ADD); subRet3 += sum3.reduceLanes(VectorOperators.ADD); } - // tail as bytes + // process scalar tail in.seek(offset); + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); + } + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); + } for (; i < length; i++) { int dValue = in.readByte() & 0xFF; subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);