Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@
*/
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {

// The percentage of centroids that are scored to keep recall
public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2;

public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
super(state, rawVectorsReader);
}
Expand Down Expand Up @@ -89,7 +86,8 @@ CentroidIterator getCentroidIterator(
int numCentroids,
IndexInput centroids,
float[] targetQuery,
IndexInput postingListSlice
IndexInput postingListSlice,
float visitRatio
) throws IOException {
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
Expand All @@ -112,8 +110,11 @@ CentroidIterator getCentroidIterator(
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
centroids.seek(0L);
int numParents = centroids.readVInt();

CentroidIterator centroidIterator;
if (numParents > 0) {
// equivalent to (float) centroidsPerParentCluster / 2
float centroidOversampling = (float) fieldEntry.numCentroids() / (2 * numParents);
centroidIterator = getCentroidIteratorWithParents(
fieldInfo,
centroids,
Expand All @@ -122,7 +123,8 @@ CentroidIterator getCentroidIterator(
scorer,
quantized,
queryParams,
globalCentroidDp
globalCentroidDp,
visitRatio * centroidOversampling
);
} else {
centroidIterator = getCentroidIteratorNoParent(
Expand Down Expand Up @@ -185,13 +187,14 @@ private static CentroidIterator getCentroidIteratorWithParents(
ES92Int7VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp
float globalCentroidDp,
float centroidRatio
) throws IOException {
// build the three queues we are going to use
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
final int maxChildrenSize = centroids.readVInt();
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids);
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
// score the parents
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ abstract CentroidIterator getCentroidIterator(
int numCentroids,
IndexInput centroids,
float[] target,
IndexInput postingListSlice
IndexInput postingListSlice,
float visitRatio
) throws IOException;

private static IndexInput openDataInput(
Expand Down Expand Up @@ -252,7 +253,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
entry.numCentroids,
entry.centroidSlice(ivfCentroids),
target,
postListSlice
postListSlice,
visitRatio
);
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs);
long expectedDocs = 0;
Expand Down