diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 47c6bb99eabb5..73b591712aa67 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -15,7 +15,6 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; @@ -24,6 +23,7 @@ import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.function.IntPredicate; @@ -32,7 +32,7 @@ import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.NPROBE_OVERSAMPLE; import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA; import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; @@ -48,31 +48,114 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect } @Override - CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery) + CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, int nProbe, IndexInput centroids, float[] targetQuery) throws IOException { + + final float[] centroidCorrectiveValues = new float[3]; + final int numOversampled = Math.min((int) (nProbe * NPROBE_OVERSAMPLE), numCentroids); + + // constants + final int discretizedDimensions = BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64); + final long oneBitQuantizeCentroidByteSize = (long) discretizedDimensions / 8; + final long oneBitQuantizeCentroidsLength = (long) numCentroids * (oneBitQuantizeCentroidByteSize + 3 * Float.BYTES + Short.BYTES); + final long quantizationCentroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + final long quantizeCentroidsLength = (long) numCentroids * quantizationCentroidByteSize; + final FieldEntry fieldEntry = fields.get(fieldInfo.number); final float globalCentroidDp = fieldEntry.globalCentroidDp(); + + // quantize the query final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); final int[] scratch = new int[targetQuery.length]; + + // FIXME: do l2normalize here? + final float[] scratchTarget = new float[targetQuery.length]; + System.arraycopy(targetQuery, 0, scratchTarget, 0, targetQuery.length); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(scratchTarget); + } final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( - ArrayUtil.copyArray(targetQuery), + scratchTarget, scratch, (byte) 4, fieldEntry.globalCentroid() ); - final byte[] quantized = new byte[targetQuery.length]; - for (int i = 0; i < quantized.length; i++) { - quantized[i] = (byte) scratch[i]; + + // pack the quantized value the way the one bit scorer expects it as 4 bits + final byte[] oneBitQuantized = new byte[QUERY_BITS * discretizedDimensions / 8]; + transposeHalfByte(scratch, oneBitQuantized); + + // pack the quantized value the way the four bit scorer expects it as single bytes instead of 4 bits + final byte[] fourBitQuantized = new byte[targetQuery.length]; + for (int i = 0; i < fourBitQuantized.length; i++) { + fourBitQuantized[i] = (byte) scratch[i]; } - final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); - NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true); - centroids.seek(0L); - final float[] centroidCorrectiveValues = new float[3]; - for (int i = 0; i < numCentroids; i++) { - final float qcDist = scorer.int4DotProduct(quantized); + + // setup to score the centroids + final ES91OSQVectorsScorer oneBitScorer = ESVectorUtil.getES91OSQVectorsScorer(centroids, fieldInfo.getVectorDimension()); + final ES91Int4VectorsScorer fourBitScorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); + + final NeighborQueue oneBitCentroidQueue = new NeighborQueue(fieldEntry.numCentroids(), true); + final NeighborQueue fourBitCentroidQueue = new NeighborQueue(numOversampled, true); + + // FIXME: this adds complexity to this method consider moving it / discuss + OneBitCentroidScorer oneBitCentroidScorer = new OneBitCentroidScorer() { + @Override + public void bulkScore(NeighborQueue queue) throws IOException { + final float[] scores = new float[BULK_SIZE]; + + centroids.seek(0L); + + // block processing + int limit = numCentroids - BULK_SIZE + 1; + int i = 0; + for (; i < limit; i += BULK_SIZE) { + oneBitScorer.scoreBulk( + oneBitQuantized, + queryParams.lowerInterval(), + queryParams.upperInterval(), + queryParams.quantizedComponentSum(), + queryParams.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + globalCentroidDp, + scores + ); + for (int j = 0; j < BULK_SIZE; j++) { + queue.add(i + j, scores[j]); + } + } + // process tail + for (; i < numCentroids; i++) { + queue.add(i, score()); + } + } + + private float score() throws IOException { + final float qcDist = oneBitScorer.quantizeScore(oneBitQuantized); + centroids.readFloats(centroidCorrectiveValues, 0, 3); + final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort()); + return oneBitScorer.score( + queryParams.lowerInterval(), + queryParams.upperInterval(), + queryParams.quantizedComponentSum(), + queryParams.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + globalCentroidDp, + centroidCorrectiveValues[0], + centroidCorrectiveValues[1], + quantizedCentroidComponentSum, + centroidCorrectiveValues[2], + qcDist + ); + } + }; + + FourBitCentroidScorer fourBitCentroidScorer = centroidOrdinal -> { + centroids.seek(oneBitQuantizeCentroidsLength + quantizationCentroidByteSize * centroidOrdinal); + final float qcDist = fourBitScorer.int4DotProduct(fourBitQuantized); centroids.readFloats(centroidCorrectiveValues, 0, 3); final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort()); - float score = int4QuantizedScore( + return int4QuantizedScore( qcDist, queryParams, fieldInfo.getVectorDimension(), @@ -81,24 +164,65 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde globalCentroidDp, fieldInfo.getVectorSimilarityFunction() ); - queue.add(i, score); - } - final long offset = centroids.getFilePointer(); + }; + + // populate the first set of centroids + oneBitCentroidScorer.bulkScore(oneBitCentroidQueue); + populateCentroidQueue(numOversampled, oneBitCentroidQueue, fourBitCentroidQueue, fourBitCentroidScorer); + return new CentroidIterator() { @Override public boolean hasNext() { - return queue.size() > 0; + return oneBitCentroidQueue.size() + fourBitCentroidQueue.size() > 0; } - @Override public long nextPostingListOffset() throws IOException { - int centroidOrdinal = queue.pop(); - centroids.seek(offset + (long) Long.BYTES * centroidOrdinal); + int centroidOrdinal = fourBitCentroidQueue.pop(); + if (oneBitCentroidQueue.size() > 0) { + // TODO: it may be more efficient as far as disk reads to pop a set of ordinals, + // sort them, and do a batch read of for instance the next max(0.1f * rescoreSize, 1) + int centroidOrd = oneBitCentroidQueue.pop(); + fourBitCentroidQueue.add(centroidOrd, fourBitCentroidScorer.score(centroidOrd)); + } + + centroids.seek(oneBitQuantizeCentroidsLength + quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal); return centroids.readLong(); } }; } + interface OneBitCentroidScorer { + void bulkScore(NeighborQueue queue) throws IOException; + } + + interface FourBitCentroidScorer { + float score(int centroidOrdinal) throws IOException; + } + + private static void populateCentroidQueue( + int rescoreSize, + NeighborQueue oneBitCentroidQueue, + NeighborQueue centroidQueue, + FourBitCentroidScorer centroidQueryScorer + ) throws IOException { + + if (oneBitCentroidQueue.size() == 0) { + return; + } + + int[] centroidOrdinalsToRescore = new int[Math.min(rescoreSize, oneBitCentroidQueue.size())]; + for (int i = 0; i < centroidOrdinalsToRescore.length; i++) { + centroidOrdinalsToRescore[i] = oneBitCentroidQueue.pop(); + } + // do this sort so we are seeking on disk in order + Arrays.sort(centroidOrdinalsToRescore); + + // TODO: bulk read the in chunks where possible, group up sets of contiguous ordinals + for (int i = 0; i < centroidOrdinalsToRescore.length; i++) { + centroidQueue.add(centroidOrdinalsToRescore[i], centroidQueryScorer.score(centroidOrdinalsToRescore[i])); + } + } + // TODO can we do this in off-heap blocks? private float int4QuantizedScore( float qcDist, @@ -186,7 +310,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor { centroid = new float[fieldInfo.getVectorDimension()]; scratch = new float[target.length]; quantizationScratch = new int[target.length]; - final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64); + final int discretizedDimensions = BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64); quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; quantizedVectorByteSize = (discretizedDimensions / 8); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index f47ecc549831a..1603747aeb1ce 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -288,14 +288,44 @@ void writeCentroids( LongValues offsets, IndexOutput centroidOutput ) throws IOException { + // TODO do we want to store these distances as well for future use? + // TODO: sort centroids by global centroid (was doing so previously here) final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); int[] quantizedScratch = new int[fieldInfo.getVectorDimension()]; float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; - final byte[] quantized = new byte[fieldInfo.getVectorDimension()]; - // TODO do we want to store these distances as well for future use? - // TODO: sort centroids by global centroid (was doing so previously here) - // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned + final byte[] oneBitQuantized = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8]; + final byte[] fourBitquantized = new byte[fieldInfo.getVectorDimension()]; + + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, centroidOutput); + bulkWriter.writeVectors(new QuantizedVectorValues() { + int currOrd = -1; + OptimizedScalarQuantizer.QuantizationResult corrections; + + @Override + public int count() { + return centroidSupplier.size(); + } + + @Override + public byte[] next() throws IOException { + currOrd++; + float[] centroid = centroidSupplier.centroid(currOrd); + System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); + this.corrections = osq.scalarQuantize(centroidScratch, quantizedScratch, (byte) 1, globalCentroid); + BQVectorUtils.packAsBinary(quantizedScratch, oneBitQuantized); + return oneBitQuantized; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No centroid read yet, call next first"); + } + return corrections; + } + }); + for (int i = 0; i < centroidSupplier.size(); i++) { float[] centroid = centroidSupplier.centroid(i); System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); @@ -306,9 +336,9 @@ void writeCentroids( globalCentroid ); for (int j = 0; j < quantizedScratch.length; j++) { - quantized[j] = (byte) quantizedScratch[j]; + fourBitquantized[j] = (byte) quantizedScratch[j]; } - writeQuantizedValue(centroidOutput, quantized, result); + writeQuantizedValue(centroidOutput, fourBitquantized, result); } // write the centroid offsets at the end of the file for (int i = 0; i < centroidSupplier.size(); i++) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java index 7a18558703423..c97c8cc514e00 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -65,6 +65,8 @@ public class IVFVectorsFormat extends KnnVectorsFormat { public static final int DEFAULT_VECTORS_PER_CLUSTER = 384; public static final int MIN_VECTORS_PER_CLUSTER = 64; public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536 + // TODO: expose oversampling as a param? + public static final float NPROBE_OVERSAMPLE = 1f; private final int vectorPerCluster; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index b570bd83f36e4..504efc5ff4230 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -88,8 +88,13 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR } } - abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target) - throws IOException; + abstract CentroidIterator getCentroidIterator( + FieldInfo fieldInfo, + int numCentroids, + int numOversampled, + IndexInput centroids, + float[] target + ) throws IOException; private static IndexInput openDataInput( SegmentReadState state, @@ -244,7 +249,14 @@ public final void search(String field, float[] target, KnnCollector knnCollector // clip to be between 1 and the number of centroids nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1); } - CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target); + + CentroidIterator centroidIterator = getCentroidIterator( + fieldInfo, + entry.numCentroids, + nProbe, + entry.centroidSlice(ivfCentroids), + target + ); PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring); int centroidsVisited = 0; long expectedDocs = 0;