diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index a00ee5c0e0205..7f11698661206 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -19,8 +19,8 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; -import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; @@ -61,14 +61,14 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( targetQueryCopy, scratch, - (byte) 4, + (byte) 7, fieldEntry.globalCentroid() ); final byte[] quantized = new byte[targetQuery.length]; for (int i = 0; i < quantized.length; i++) { quantized[i] = (byte) scratch[i]; } - final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); + final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension()); centroids.seek(0L); int numParents = centroids.readVInt(); if (numParents > 0) { @@ -90,7 +90,7 @@ private static CentroidIterator getCentroidIteratorNoParent( FieldInfo fieldInfo, IndexInput centroids, int numCentroids, - ES91Int4VectorsScorer scorer, + ES92Int7VectorsScorer scorer, byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryParams, float globalCentroidDp @@ -105,7 +105,7 @@ private static CentroidIterator getCentroidIteratorNoParent( queryParams, globalCentroidDp, fieldInfo.getVectorSimilarityFunction(), - new float[ES91Int4VectorsScorer.BULK_SIZE] + new float[ES92Int7VectorsScorer.BULK_SIZE] ); long offset = centroids.getFilePointer(); return new CentroidIterator() { @@ -128,7 +128,7 @@ private static CentroidIterator getCentroidIteratorWithParents( IndexInput centroids, int numParents, int numCentroids, - ES91Int4VectorsScorer scorer, + ES92Int7VectorsScorer scorer, byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryParams, float globalCentroidDp @@ -140,7 +140,7 @@ private static CentroidIterator getCentroidIteratorWithParents( final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1); final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true); // score the parents - final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE]; + final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE]; score( parentsQueue, numParents, @@ -152,7 +152,7 @@ private static CentroidIterator getCentroidIteratorWithParents( fieldInfo.getVectorSimilarityFunction(), scores ); - final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES; final long offset = centroids.getFilePointer(); final long childrenOffset = offset + (long) Long.BYTES * numParents; // populate the children's queue by reading parents one by one @@ -227,7 +227,7 @@ private static void populateOneChildrenGroup( long childrenOffset, long centroidQuantizeSize, FieldInfo fieldInfo, - ES91Int4VectorsScorer scorer, + ES92Int7VectorsScorer scorer, byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryParams, float globalCentroidDp, @@ -254,16 +254,16 @@ private static void score( NeighborQueue neighborQueue, int size, int scoresOffset, - ES91Int4VectorsScorer scorer, + ES92Int7VectorsScorer scorer, byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryCorrections, float centroidDp, VectorSimilarityFunction similarityFunction, float[] scores ) throws IOException { - int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1; + int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1; int i = 0; - for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) { + for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) { scorer.scoreBulk( quantizeQuery, queryCorrections.lowerInterval(), @@ -274,7 +274,7 @@ private static void score( centroidDp, scores ); - for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) { + for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) { neighborQueue.add(scoresOffset + i + j, scores[j]); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 58f09cf70d4bd..ce54afbd9e90e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -26,8 +26,8 @@ import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import java.io.IOException; import java.io.UncheckedIOException; @@ -315,8 +315,8 @@ private void writeCentroidsWithParents( LongValues offsets, IndexOutput centroidOutput ) throws IOException { - DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter( - ES91Int4VectorsScorer.BULK_SIZE, + DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter( + ES92Int7VectorsScorer.BULK_SIZE, centroidOutput ); final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); @@ -365,8 +365,8 @@ private void writeCentroidsWithoutParents( IndexOutput centroidOutput ) throws IOException { centroidOutput.writeVInt(0); - DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter( - ES91Int4VectorsScorer.BULK_SIZE, + DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter( + ES92Int7VectorsScorer.BULK_SIZE, centroidOutput ); final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); @@ -571,7 +571,7 @@ public byte[] next() throws IOException { // Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice // due to overspill, so we copy the vector to a scratch array System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length); - corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 4, centroid); + corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 7, centroid); for (int i = 0; i < quantizedVectorScratch.length; i++) { quantizedVector[i] = (byte) quantizedVectorScratch[i]; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java index 9da77fb77661a..1d7e7f74f6c14 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java @@ -29,32 +29,6 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) { abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException; - private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException { - for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { - out.writeInt(Float.floatToIntBits(correction.lowerInterval())); - } - for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { - out.writeInt(Float.floatToIntBits(correction.upperInterval())); - } - for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { - int targetComponentSum = correction.quantizedComponentSum(); - assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; - out.writeShort((short) targetComponentSum); - } - for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { - out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); - } - } - - private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction, IndexOutput out) throws IOException { - out.writeInt(Float.floatToIntBits(correction.lowerInterval())); - out.writeInt(Float.floatToIntBits(correction.upperInterval())); - out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); - int targetComponentSum = correction.quantizedComponentSum(); - assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; - out.writeShort((short) targetComponentSum); - } - static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final OptimizedScalarQuantizer.QuantizationResult[] corrections; @@ -73,22 +47,48 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx corrections[j] = qvv.getCorrections(); out.writeBytes(qv, qv.length); } - writeCorrections(corrections, out); + writeCorrections(corrections); } // write tail for (; i < qvv.count(); ++i) { byte[] qv = qvv.next(); OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections(); out.writeBytes(qv, qv.length); - writeCorrection(correction, out); + writeCorrection(correction); + } + } + + private void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections) throws IOException { + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + int targetComponentSum = correction.quantizedComponentSum(); + assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; + out.writeShort((short) targetComponentSum); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); } } + + private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction) throws IOException { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + int targetComponentSum = correction.quantizedComponentSum(); + assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; + out.writeShort((short) targetComponentSum); + } } - static class FourBitDiskBBQBulkWriter extends DiskBBQBulkWriter { + static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final OptimizedScalarQuantizer.QuantizationResult[] corrections; - FourBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { + SevenBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { super(bulkSize, out); this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; } @@ -103,15 +103,37 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx corrections[j] = qvv.getCorrections(); out.writeBytes(qv, qv.length); } - writeCorrections(corrections, out); + writeCorrections(corrections); } // write tail for (; i < qvv.count(); ++i) { byte[] qv = qvv.next(); OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections(); out.writeBytes(qv, qv.length); - writeCorrection(correction, out); + writeCorrection(correction); } } + + private void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections) throws IOException { + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(correction.quantizedComponentSum()); + } + for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + } + } + + private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correction) throws IOException { + out.writeInt(Float.floatToIntBits(correction.lowerInterval())); + out.writeInt(Float.floatToIntBits(correction.upperInterval())); + out.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + out.writeInt(correction.quantizedComponentSum()); + } } }