diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index f51c550e5292e..c4cd4e8b7bdc0 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -52,7 +52,8 @@ record CmdLineArgs( int quantizeBits, VectorEncoding vectorEncoding, int dimensions, - boolean earlyTermination + boolean earlyTermination, + String mergePolicy ) implements ToXContentObject { static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); @@ -79,6 +80,7 @@ record CmdLineArgs( static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity"); static final ParseField SEED_FIELD = new ParseField("seed"); + static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy"); static CmdLineArgs fromXContent(XContentParser parser) throws IOException { Builder builder = PARSER.apply(parser, null); @@ -112,6 +114,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD); PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD); PARSER.declareLong(Builder::setSeed, SEED_FIELD); + PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD); } @Override @@ -179,6 +182,7 @@ static class Builder { private boolean earlyTermination; private float filterSelectivity = 1f; private long seed = 1751900822751L; + private String mergePolicy = null; public Builder setDocVectors(List docVectors) { if (docVectors == null || docVectors.isEmpty()) { @@ -304,6 +308,11 @@ public Builder setSeed(long seed) { return this; } + public Builder setMergePolicy(String mergePolicy) { + this.mergePolicy = mergePolicy; + return this; + } + public CmdLineArgs build() { if (docVectors == null) { throw new IllegalArgumentException("Document vectors path must be provided"); @@ -337,7 +346,8 @@ public CmdLineArgs build() { quantizeBits, vectorEncoding, dimensions, - earlyTermination + earlyTermination, + mergePolicy ); } } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index c4b0ccdfe35e3..a4ffe8fc5295d 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -15,6 +15,10 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.index.LogByteSizeMergePolicy; +import org.apache.lucene.index.MergePolicy; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.TieredMergePolicy; import org.elasticsearch.cli.ProcessInfo; import org.elasticsearch.common.Strings; import org.elasticsearch.common.logging.LogConfigurator; @@ -196,6 +200,16 @@ public static void main(String[] args) throws Exception { logger.info("Running KNN index tester with arguments: " + cmdLineArgs); Codec codec = createCodec(cmdLineArgs); Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs)); + MergePolicy mergePolicy = null; + if (cmdLineArgs.mergePolicy() != null && cmdLineArgs.mergePolicy().isEmpty() == false) { + if ("tmp".equalsIgnoreCase(cmdLineArgs.mergePolicy())) { + mergePolicy = new TieredMergePolicy(); + } else if ("lbmp".equalsIgnoreCase(cmdLineArgs.mergePolicy())) { + mergePolicy = new LogByteSizeMergePolicy(); + } else if ("no".equalsIgnoreCase(cmdLineArgs.mergePolicy())) { + mergePolicy = NoMergePolicy.INSTANCE; + } + } if (cmdLineArgs.reindex() || cmdLineArgs.forceMerge()) { KnnIndexer knnIndexer = new KnnIndexer( cmdLineArgs.docVectors(), @@ -205,7 +219,8 @@ public static void main(String[] args) throws Exception { cmdLineArgs.vectorEncoding(), cmdLineArgs.dimensions(), cmdLineArgs.vectorSpace(), - cmdLineArgs.numDocs() + cmdLineArgs.numDocs(), + mergePolicy ); if (cmdLineArgs.reindex() == false && Files.exists(indexPath) == false) { throw new IllegalArgumentException("Index path does not exist: " + indexPath); diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java index f7d00c9806c8d..f3cad02640d3c 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java @@ -31,6 +31,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.MergePolicy; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FSDirectory; @@ -64,6 +65,7 @@ class KnnIndexer { private final List docsPath; private final Path indexPath; private final VectorEncoding vectorEncoding; + private final MergePolicy mergePolicy; private int dim; private final VectorSimilarityFunction similarityFunction; private final Codec codec; @@ -78,7 +80,8 @@ class KnnIndexer { VectorEncoding vectorEncoding, int dim, VectorSimilarityFunction similarityFunction, - int numDocs + int numDocs, + MergePolicy mergePolicy ) { this.docsPath = docsPath; this.indexPath = indexPath; @@ -88,6 +91,7 @@ class KnnIndexer { this.dim = dim; this.similarityFunction = similarityFunction; this.numDocs = numDocs; + this.mergePolicy = mergePolicy; } void numSegments(KnnIndexTester.Results result) { @@ -104,6 +108,9 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB); iwc.setUseCompoundFile(false); + if (mergePolicy != null) { + iwc.setMergePolicy(mergePolicy); + } iwc.setMaxFullFlushMergeWaitMillis(0); iwc.setInfoStream(new PrintStreamInfoStream(System.out) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 304cc57284227..9e36bc1142ac5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -24,6 +24,7 @@ import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.function.IntPredicate; @@ -48,7 +49,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect } @Override - CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery) + CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery, int nProbe) throws IOException { final FieldEntry fieldEntry = fields.get(fieldInfo.number); final float globalCentroidDp = fieldEntry.globalCentroidDp(); @@ -68,6 +69,8 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind return new CentroidQueryScorer() { int currentCentroid = -1; long postingListOffset; + float diff = Float.NaN; + private final float[] centroidCorrectiveValues = new float[3]; private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES); @@ -90,11 +93,30 @@ public long postingListOffset(int centroidOrdinal) throws IOException { public void bulkScore(NeighborQueue queue) throws IOException { // TODO: bulk score centroids like we do with posting lists centroids.seek(0L); + float[] centroidsScratch = null; + if (numCentroids > nProbe) { + centroidsScratch = new float[numCentroids]; + } for (int i = 0; i < numCentroids; i++) { - queue.add(i, score()); + float score = score(); + queue.add(i, score); + if (numCentroids > nProbe) { + centroidsScratch[i] = score; + } + } + if (numCentroids > nProbe) { + Arrays.sort(centroidsScratch); + float topScore = centroidsScratch[nProbe - 1]; + float nprobeScore = centroidsScratch[0]; + diff = (topScore - nprobeScore) / topScore; } } + @Override + public float scoreRatioAtNprobe() { + return diff; + } + private float score() throws IOException { final float qcDist = scorer.int4DotProduct(quantized); centroids.readFloats(centroidCorrectiveValues, 0, 3); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index 01cced04a9fcc..4179a9f309343 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -89,7 +89,7 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR } } - abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target) + abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target, int nProbe) throws IOException; private static IndexInput openDataInput( @@ -236,22 +236,34 @@ public final void search(String field, float[] target, KnnCollector knnCollector } FieldEntry entry = fields.get(fieldInfo.number); - CentroidQueryScorer centroidQueryScorer = getCentroidScorer( - fieldInfo, - entry.numCentroids, - entry.centroidSlice(ivfCentroids), - target - ); + int numCentroids = entry.numCentroids; if (nProbe == DYNAMIC_NPROBE) { // empirically based, and a good dynamic to get decent recall while scaling a la "efSearch" // scaling by the number of centroids vs. the nearest neighbors requested // not perfect, but a comparative heuristic. // we might want to utilize the total vector count as well, but this is a good start - nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k())); + nProbe = (int) Math.round(Math.log10(numCentroids) * Math.sqrt(knnCollector.k())); // clip to be between 1 and the number of centroids - nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1); + nProbe = Math.max(Math.min(nProbe, numCentroids), 1); } + + CentroidQueryScorer centroidQueryScorer = getCentroidScorer( + fieldInfo, + numCentroids, + entry.centroidSlice(ivfCentroids), + target, + nProbe + ); + final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe); + + if (centroidQueue.size() > 2 && numCentroids > nProbe) { + // If the difference is small, increase nprobe to search more centroids + if (centroidQueryScorer.scoreRatioAtNprobe() < 0.001f) { + nProbe = (int) Math.min(nProbe * 2.0, numCentroids); + } + } + PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring); int centroidsVisited = 0; long expectedDocs = 0; @@ -329,6 +341,9 @@ interface CentroidQueryScorer { long postingListOffset(int centroidOrdinal) throws IOException; void bulkScore(NeighborQueue queue) throws IOException; + + float scoreRatioAtNprobe(); + } interface PostingVisitor {