From 476a7d858c99ee482c0e2d66e376feac1808b14e Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Fri, 25 Jul 2025 07:51:58 +0100 Subject: [PATCH 1/2] Fix score computation in ES91Int4VectorsScorer --- .../simdvec/ES91Int4VectorsScorer.java | 2 +- .../MemorySegmentES91Int4VectorsScorer.java | 2 +- .../ES91Int4VectorScorerTests.java | 109 ++++++++++++------ 3 files changed, 74 insertions(+), 39 deletions(-) diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java index 95415cee2b090..4bab4295e9861 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java @@ -162,7 +162,7 @@ public float applyCorrections( ) { float ax = lowerInterval; // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary - float lx = upperInterval - ax; + float lx = (upperInterval - ax) * FOUR_BIT_SCALE; float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java index 7aaacae89be74..20e1c0dce0ccf 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java @@ -354,7 +354,7 @@ private void applyCorrectionsBulk( memorySegment, offset + 4 * BULK_SIZE + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN - ).sub(ax); + ).sub(ax).mul(FOUR_BIT_SCALE); var targetComponentSums = ShortVector.fromMemorySegment( SHORT_SPECIES, memorySegment, diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java index 34ae512b0765e..7e71c3622a5cc 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java @@ -20,7 +20,9 @@ import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; -import static org.hamcrest.Matchers.lessThan; +import java.io.IOException; + +import static org.hamcrest.Matchers.greaterThan; public class ES91Int4VectorScorerTests extends BaseVectorizationTests { @@ -130,31 +132,59 @@ public void testInt4ScoreBulk() throws Exception { // only even dimensions are supported final int dimensions = random().nextInt(1, 1000) * 2; final int numVectors = random().nextInt(1, 10) * ES91Int4VectorsScorer.BULK_SIZE; - final byte[] vector = new byte[ES91Int4VectorsScorer.BULK_SIZE * dimensions]; - final byte[] corrections = new byte[ES91Int4VectorsScorer.BULK_SIZE * 14]; + final float[][] vectors = new float[numVectors][dimensions]; + final int[] quantizedScratch = new int[dimensions]; + final byte[] quantizeVector = new byte[dimensions]; + final float[] centroid = new float[dimensions]; + VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values()); + for (int i = 0; i < dimensions; i++) { + centroid[i] = random().nextFloat(); + } + if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) { + VectorUtil.l2normalize(centroid); + } + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); try (Directory dir = new MMapDirectory(createTempDir())) { try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { + OptimizedScalarQuantizer.QuantizationResult[] results = + new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE]; for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) { - for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE * dimensions; j++) { - vector[j] = (byte) random().nextInt(16); // 4-bit quantization + for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) { + for (int k = 0; k < dimensions; k++) { + vectors[i + j][k] = random().nextFloat(); + } + if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) { + VectorUtil.l2normalize(vectors[i + j]); + } + results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 4, centroid); + for (int k = 0; k < dimensions; k++) { + quantizeVector[k] = (byte) quantizedScratch[k]; + } + out.writeBytes(quantizeVector, 0, dimensions); } - out.writeBytes(vector, 0, vector.length); - random().nextBytes(corrections); - out.writeBytes(corrections, 0, corrections.length); + writeCorrections(results, out); } } - final byte[] query = new byte[dimensions]; + final float[] query = new float[dimensions]; + final byte[] quantizeQuery = new byte[dimensions]; for (int j = 0; j < dimensions; j++) { - query[j] = (byte) random().nextInt(16); // 4-bit quantization + query[j] = random().nextFloat(); } - OptimizedScalarQuantizer.QuantizationResult queryCorrections = new OptimizedScalarQuantizer.QuantizationResult( - random().nextFloat(), - random().nextFloat(), - random().nextFloat(), - Short.toUnsignedInt((short) random().nextInt()) + if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) { + VectorUtil.l2normalize(query); + } + OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize( + query.clone(), + quantizedScratch, + (byte) 4, + centroid ); - float centroidDp = random().nextFloat(); - VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values()); + for (int j = 0; j < dimensions; j++) { + quantizeQuery[j] = (byte) quantizedScratch[j]; + } + float centroidDp = VectorUtil.dotProduct(centroid, centroid); + try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { // Work on a slice that has just the right number of bytes to make the test fail with an // index-out-of-bounds in case the implementation reads more than the allowed number of @@ -166,7 +196,7 @@ public void testInt4ScoreBulk() throws Exception { float[] scoresPanama = new float[ES91Int4VectorsScorer.BULK_SIZE]; for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) { defaultScorer.scoreBulk( - query, + quantizeQuery, queryCorrections.lowerInterval(), queryCorrections.upperInterval(), queryCorrections.quantizedComponentSum(), @@ -176,7 +206,7 @@ public void testInt4ScoreBulk() throws Exception { scoresDefault ); panamaScorer.scoreBulk( - query, + quantizeQuery, queryCorrections.lowerInterval(), queryCorrections.upperInterval(), queryCorrections.quantizedComponentSum(), @@ -186,24 +216,12 @@ public void testInt4ScoreBulk() throws Exception { scoresPanama ); for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - if (scoresDefault[j] == scoresPanama[j]) { - continue; - } - if (scoresDefault[j] > (1000 * Byte.MAX_VALUE)) { - float diff = Math.abs(scoresDefault[j] - scoresPanama[j]); - assertThat( - "defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j], - diff / scoresDefault[j], - lessThan(1e-5f) - ); - assertThat( - "defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j], - diff / scoresPanama[j], - lessThan(1e-5f) - ); - } else { - assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f); - } + assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f); + float realSimilarity = similarityFunction.compare(vectors[i + j], query); + float accuracy = realSimilarity > scoresDefault[j] + ? scoresDefault[j] / realSimilarity + : realSimilarity / scoresDefault[j]; + assertThat(accuracy, greaterThan(0.90f)); } assertEquals(in.getFilePointer(), slice.getFilePointer()); } @@ -211,4 +229,21 @@ public void testInt4ScoreBulk() throws Exception { } } } + + private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException { + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + int targetComponentSum = correction.quantizedComponentSum(); + assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; + out.writeShort((short) targetComponentSum); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + } + } } From 1215e90f62efe04b29ad0cc68c91eb3a66b205c4 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Fri, 25 Jul 2025 07:53:23 +0100 Subject: [PATCH 2/2] iter --- .../java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java index 4bab4295e9861..8fca63001003a 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java @@ -161,7 +161,6 @@ public float applyCorrections( float qcDist ) { float ax = lowerInterval; - // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary float lx = (upperInterval - ax) * FOUR_BIT_SCALE; float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;