Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@
*/
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {

// The fraction of centroids (0.2, i.e. 20%) that are scored to keep recall
public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2;
// Oversampling factor for hierarchical centroids: the visit ratio is multiplied by this value to decide
// how many centroids to collect. It is set to half of centroidsPerParentCluster.
public final float centroidOversampling;

public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
/**
 * Creates the reader and derives the centroid oversampling factor.
 *
 * @param state segment read state, forwarded unchanged to the parent reader
 * @param rawVectorsReader reader for the raw (flat) vectors, forwarded unchanged to the parent reader
 * @param centroidsPerParentCluster value from which {@link #centroidOversampling} is derived (half of it)
 * @throws IOException declared because the parent constructor may throw it
 */
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader, int centroidsPerParentCluster)
throws IOException {
super(state, rawVectorsReader);
// The cast binds tighter than the division, so this is a float division: odd values keep their .5.
centroidOversampling = (float) centroidsPerParentCluster / 2;
}

CentroidIterator getPostingListPrefetchIterator(CentroidIterator centroidIterator, IndexInput postingListSlice) throws IOException {
Expand Down Expand Up @@ -89,7 +91,8 @@ CentroidIterator getCentroidIterator(
int numCentroids,
IndexInput centroids,
float[] targetQuery,
IndexInput postingListSlice
IndexInput postingListSlice,
float visitRatio
) throws IOException {
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
Expand Down Expand Up @@ -122,7 +125,8 @@ CentroidIterator getCentroidIterator(
scorer,
quantized,
queryParams,
globalCentroidDp
globalCentroidDp,
visitRatio * centroidOversampling
);
} else {
centroidIterator = getCentroidIteratorNoParent(
Expand Down Expand Up @@ -185,13 +189,14 @@ private static CentroidIterator getCentroidIteratorWithParents(
ES92Int7VectorsScorer scorer,
byte[] quantizeQuery,
OptimizedScalarQuantizer.QuantizationResult queryParams,
float globalCentroidDp
float globalCentroidDp,
float centroidRatio
) throws IOException {
// build the three queues we are going to use
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
final int maxChildrenSize = centroids.readVInt();
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1);
final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids);
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
// score the parents
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException

/**
 * Creates the vectors reader for a segment: a {@link DefaultIVFVectorsReader} wrapping the reader
 * obtained from the raw vector format.
 */
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state));
// centroidsPerParentCluster is forwarded so the reader can derive its centroid oversampling factor from it.
return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state), centroidsPerParentCluster);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ abstract CentroidIterator getCentroidIterator(
int numCentroids,
IndexInput centroids,
float[] target,
IndexInput postingListSlice
IndexInput postingListSlice,
float visitRatio
) throws IOException;

private static IndexInput openDataInput(
Expand Down Expand Up @@ -252,7 +253,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
entry.numCentroids,
entry.centroidSlice(ivfCentroids),
target,
postListSlice
postListSlice,
visitRatio
);
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs);
long expectedDocs = 0;
Expand Down