From bc80a22282a348bae364e3dcff0c12ad6ad663e1 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Mon, 3 Nov 2025 11:00:59 +0100 Subject: [PATCH 1/3] DiskBBQ - Panama support for 4 bits symmetric quantization --- .../MSBitToInt4ESNextOSQVectorsScorer.java | 18 - ...MSInt4SymmetricESNextOSQVectorsScorer.java | 419 ++++++++++++++++++ .../MemorySegmentESNextOSQVectorsScorer.java | 24 +- .../PanamaESVectorizationProvider.java | 2 +- 4 files changed, 443 insertions(+), 20 deletions(-) create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java index 14b0cc3499ae5..888388ce2caa6 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSBitToInt4ESNextOSQVectorsScorer.java @@ -14,7 +14,6 @@ import jdk.incubator.vector.LongVector; import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorSpecies; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; @@ -31,23 +30,6 @@ /** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */ final class MSBitToInt4ESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer { - private static final int BULK_SIZE = MemorySegmentESNextOSQVectorsScorer.BULK_SIZE; - private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); - - private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; - - private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; - private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; - - private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; - private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; - - private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; - private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; - - private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; - private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; - MSBitToInt4ESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, MemorySegment memorySegment) { super(in, dimensions, dataLength, memorySegment); } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java new file mode 100644 index 0000000000000..9263c4ce56282 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MSInt4SymmetricESNextOSQVectorsScorer.java @@ -0,0 +1,419 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.VectorOperators; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.VectorUtil; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; + +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */ +final class MSInt4SymmetricESNextOSQVectorsScorer extends MemorySegmentESNextOSQVectorsScorer.MemorySegmentScorer { + + MSInt4SymmetricESNextOSQVectorsScorer(IndexInput in, int dimensions, int dataLength, MemorySegment memorySegment) { + super(in, dimensions, dataLength, memorySegment); + } + + @Override + public long quantizeScore(byte[] q) throws IOException { + assert q.length == length; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return quantizeScoreSymmetric256(q); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return quantizeScoreSymmetric128(q); + } + } + return Long.MIN_VALUE; + } + + private long quantizeScoreSymmetric128(byte[] q) throws IOException { + int stripe0 = (int) quantizeScore128(q); + int stripe1 = (int) quantizeScore128(q); + int stripe2 = (int) quantizeScore128(q); + int stripe3 = (int) quantizeScore128(q); + return stripe0 + ((long) stripe1 << 1) + ((long) stripe2 << 2) + ((long) stripe3 << 3); + } + + private long quantizeScoreSymmetric256(byte[] q) throws IOException { + int stripe0 = (int) quantizeScore256(q); + int stripe1 = (int) quantizeScore256(q); + int stripe2 = (int) quantizeScore256(q); + int stripe3 = (int) quantizeScore256(q); + return stripe0 + ((long) stripe1 << 1) + ((long) stripe2 << 2) + ((long) stripe3 << 3); + } + + private long quantizeScore256(byte[] q) throws IOException { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + int size = length / 4; + if (size >= ByteVector.SPECIES_256.vectorByteSize() * 2) { + int limit = ByteVector.SPECIES_256.loopBound(size); + var sum0 = LongVector.zero(LONG_SPECIES_256); + var sum1 = LongVector.zero(LONG_SPECIES_256); + var sum2 = LongVector.zero(LONG_SPECIES_256); + var sum3 = LongVector.zero(LONG_SPECIES_256); + for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + size * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + + if (size - i >= ByteVector.SPECIES_128.vectorByteSize()) { + var sum0 = LongVector.zero(LONG_SPECIES_128); + var sum1 = LongVector.zero(LONG_SPECIES_128); + var sum2 = LongVector.zero(LONG_SPECIES_128); + var sum3 = LongVector.zero(LONG_SPECIES_128); + int limit = ByteVector.SPECIES_128.loopBound(size); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) { + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsLongs(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsLongs(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 3).reinterpretAsLongs(); + var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + } + // process scalar tail + in.seek(offset); + for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value); + } + for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value); + } + for (; i < size; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * size] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + private long quantizeScore128(byte[] q) throws IOException { + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + int i = 0; + long offset = in.getFilePointer(); + + var sum0 = IntVector.zero(INT_SPECIES_128); + var sum1 = IntVector.zero(INT_SPECIES_128); + var sum2 = IntVector.zero(INT_SPECIES_128); + var sum3 = IntVector.zero(INT_SPECIES_128); + int size = length / 4; + int limit = ByteVector.SPECIES_128.loopBound(size); + for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) { + var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts(); + var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size).reinterpretAsInts(); + var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 2).reinterpretAsInts(); + var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + size * 3).reinterpretAsInts(); + sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT)); + sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT)); + sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT)); + sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT)); + } + subRet0 += sum0.reduceLanes(VectorOperators.ADD); + subRet1 += sum1.reduceLanes(VectorOperators.ADD); + subRet2 += sum2.reduceLanes(VectorOperators.ADD); + subRet3 += sum3.reduceLanes(VectorOperators.ADD); + // process scalar tail + in.seek(offset); + for (final int upperBound = size & -Long.BYTES; i < upperBound; i += Long.BYTES) { + final long value = in.readLong(); + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + size) & value); + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * size) & value); + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * size) & value); + } + for (final int upperBound = size & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + final int value = in.readInt(); + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + size) & value); + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * size) & value); + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * size) & value); + } + for (; i < size; i++) { + int dValue = in.readByte() & 0xFF; + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); + subRet1 += Integer.bitCount((q[i + size] & dValue) & 0xFF); + subRet2 += Integer.bitCount((q[i + 2 * size] & dValue) & 0xFF); + subRet3 += Integer.bitCount((q[i + 3 * size] & dValue) & 0xFF); + } + return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); + } + + @Override + public boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException { + assert q.length == length; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + quantizeScore256Bulk(q, count, scores); + return true; + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + quantizeScore128Bulk(q, count, scores); + return true; + } + } + return false; + } + + private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + scores[iter] = quantizeScoreSymmetric128(q); + } + } + + private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException { + for (int iter = 0; iter < count; iter++) { + scores[iter] = quantizeScoreSymmetric256(q); + } + } + + @Override + public float scoreBulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + assert q.length == length; + // 128 / 8 == 16 + if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { + return score256Bulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { + return score128Bulk( + q, + queryLowerInterval, + queryUpperInterval, + queryComponentSum, + queryAdditionalCorrection, + similarityFunction, + centroidDp, + scores + ); + } + } + return Float.NEGATIVE_INFINITY; + } + + private float score128Bulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + quantizeScore128Bulk(q, BULK_SIZE, scores); + int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (; i < limit; i += FLOAT_SPECIES_128.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES_128, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax).mul(FOUR_BIT_SCALE); + var targetComponentSums = ShortVector.fromMemorySegment( + SHORT_SPECIES_128, + memorySegment, + offset + 8 * BULK_SIZE + i * Short.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); + var additionalCorrections = FloatVector.fromMemorySegment( + FLOAT_SPECIES_128, + memorySegment, + offset + 10 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + maxScore = Math.max(maxScore, scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + res.intoArray(scores, i); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + } + } + } + in.seek(offset + 14L * BULK_SIZE); + return maxScore; + } + + private float score256Bulk( + byte[] q, + float queryLowerInterval, + float queryUpperInterval, + int queryComponentSum, + float queryAdditionalCorrection, + VectorSimilarityFunction similarityFunction, + float centroidDp, + float[] scores + ) throws IOException { + quantizeScore256Bulk(q, BULK_SIZE, scores); + int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE); + int i = 0; + long offset = in.getFilePointer(); + float ay = queryLowerInterval; + float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; + float y1 = queryComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; + for (; i < limit; i += FLOAT_SPECIES_256.length()) { + var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); + var lx = FloatVector.fromMemorySegment( + FLOAT_SPECIES_256, + memorySegment, + offset + 4 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ).sub(ax).mul(FOUR_BIT_SCALE); + var targetComponentSums = ShortVector.fromMemorySegment( + SHORT_SPECIES_256, + memorySegment, + offset + 8 * BULK_SIZE + i * Short.BYTES, + ByteOrder.LITTLE_ENDIAN + ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0); + var additionalCorrections = FloatVector.fromMemorySegment( + FLOAT_SPECIES_256, + memorySegment, + offset + 10 * BULK_SIZE + i * Float.BYTES, + ByteOrder.LITTLE_ENDIAN + ); + var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i); + // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * + // qcDist; + var res1 = ax.mul(ay).mul(dimensions); + var res2 = lx.mul(ay).mul(targetComponentSums); + var res3 = ax.mul(ly).mul(y1); + var res4 = lx.mul(ly).mul(qcDist); + var res = res1.add(res2).add(res3).add(res4); + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); + res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + res.intoArray(scores, i); + // not sure how to do it better + for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) { + scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + maxScore = Math.max(maxScore, scores[i + j]); + } + } else { + res = res.add(1f).mul(0.5f).max(0); + maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); + res.intoArray(scores, i); + } + } + } + in.seek(offset + 14L * BULK_SIZE); + return maxScore; + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java index 52563fa959f27..192c970afae4d 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java @@ -54,6 +54,10 @@ public MemorySegmentESNextOSQVectorsScorer( this.memorySegment = memorySegment; if (queryBits == 4 && indexBits == 1) { this.scorer = new MSBitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, memorySegment); + } else if (queryBits == 4 && indexBits == 4) { + this.scorer = new MSInt4SymmetricESNextOSQVectorsScorer(in, dimensions, dataLength, memorySegment); + } else if (queryBits == 4 && indexBits == 2) { + throw new IllegalArgumentException("Only symmetric 4-bit query and 1-bit index supported"); } else { throw new IllegalArgumentException("Only asymmetric 4-bit query and 1-bit index supported"); } @@ -112,7 +116,25 @@ public float scoreBulk( ); } - abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer { + abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, MSDibitToInt4ESNextOSQVectorsScorer, + MSInt4SymmetricESNextOSQVectorsScorer { + + static final int BULK_SIZE = MemorySegmentESNextOSQVectorsScorer.BULK_SIZE; + static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; + + static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; + static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; + + static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; + + static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; + static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; + + static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; + static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; + protected final MemorySegment memorySegment; protected final IndexInput in; protected final int length; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index 17838c5bef05b..44e6937b5b23e 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -40,7 +40,7 @@ public ESNextOSQVectorsScorer newESNextOSQVectorsScorer(IndexInput input, byte q if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai && queryBits == 4 - && indexBits == 1) { + && (indexBits == 1 || indexBits == 4)) { MemorySegment ms = msai.segmentSliceOrNull(0, input.length()); if (ms != null) { return new MemorySegmentESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength, ms); From dd66ef6bf1150f0f294a372058b33333c76aa2eb Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Mon, 3 Nov 2025 11:28:23 +0100 Subject: [PATCH 2/3] leading dibit class, useless stuff removed --- .../MemorySegmentESNextOSQVectorsScorer.java | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java index 192c970afae4d..936ac7cc016cf 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java @@ -25,20 +25,6 @@ /** Panamized scorer for quantized vectors stored as a {@link MemorySegment}. */ public final class MemorySegmentESNextOSQVectorsScorer extends ESNextOSQVectorsScorer { - private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; - - private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128; - private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256; - - private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; - private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; - - private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; - private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; - - private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128; - private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; - private final MemorySegment memorySegment; private final MemorySegmentScorer scorer; @@ -116,7 +102,7 @@ public float scoreBulk( ); } - abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, MSDibitToInt4ESNextOSQVectorsScorer, + abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, MSInt4SymmetricESNextOSQVectorsScorer { static final int BULK_SIZE = MemorySegmentESNextOSQVectorsScorer.BULK_SIZE; From 925524f838f7dc7ba6857f96926a3cd0a8d201d5 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 3 Nov 2025 10:34:54 +0000 Subject: [PATCH 3/3] [CI] Auto commit changes from spotless --- .../vectorization/MemorySegmentESNextOSQVectorsScorer.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java index 936ac7cc016cf..492e19dd649ad 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentESNextOSQVectorsScorer.java @@ -102,8 +102,7 @@ public float scoreBulk( ); } - abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, - MSInt4SymmetricESNextOSQVectorsScorer { + abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer, MSInt4SymmetricESNextOSQVectorsScorer { static final int BULK_SIZE = MemorySegmentESNextOSQVectorsScorer.BULK_SIZE; static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);