diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 02c696a1fc561..3999622fdc52e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -40,9 +40,6 @@ */ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats { - // The percentage of centroids that are scored to keep recall - public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2; - public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { super(state, rawVectorsReader); } @@ -89,7 +86,8 @@ CentroidIterator getCentroidIterator( int numCentroids, IndexInput centroids, float[] targetQuery, - IndexInput postingListSlice + IndexInput postingListSlice, + float visitRatio ) throws IOException { final FieldEntry fieldEntry = fields.get(fieldInfo.number); final float globalCentroidDp = fieldEntry.globalCentroidDp(); @@ -112,8 +110,11 @@ CentroidIterator getCentroidIterator( final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension()); centroids.seek(0L); int numParents = centroids.readVInt(); + CentroidIterator centroidIterator; if (numParents > 0) { + // equivalent to (float) centroidsPerParentCluster / 2 + float centroidOversampling = (float) fieldEntry.numCentroids() / (2 * numParents); centroidIterator = getCentroidIteratorWithParents( fieldInfo, centroids, @@ -122,7 +123,8 @@ CentroidIterator getCentroidIterator( scorer, quantized, queryParams, - globalCentroidDp + globalCentroidDp, + visitRatio * centroidOversampling ); } else { centroidIterator = getCentroidIteratorNoParent( @@ -185,13 +187,14 @@ private static CentroidIterator getCentroidIteratorWithParents( ES92Int7VectorsScorer scorer, byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryParams, - float globalCentroidDp + float globalCentroidDp, + float centroidRatio ) throws IOException { // build the three queues we are going to use final NeighborQueue parentsQueue = new NeighborQueue(numParents, true); final int maxChildrenSize = centroids.readVInt(); final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true); - final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1); + final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids); final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true); // score the parents final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE]; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index ab26df796d934..a2914682ac93f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -90,7 +90,8 @@ abstract CentroidIterator getCentroidIterator( int numCentroids, IndexInput centroids, float[] target, - IndexInput postingListSlice + IndexInput postingListSlice, + float visitRatio ) throws IOException; private static IndexInput openDataInput( @@ -252,7 +253,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector entry.numCentroids, entry.centroidSlice(ivfCentroids), target, - postListSlice + postListSlice, + visitRatio ); PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs); long expectedDocs = 0;