diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 05abbb8af7f12..50c8fd107c444 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -36,7 +36,7 @@ record CmdLineArgs( KnnIndexTester.IndexType indexType, int numCandidates, int k, - int[] nProbes, + double[] visitPercentages, int ivfClusterSize, int overSamplingFactor, int hnswM, @@ -63,7 +63,8 @@ record CmdLineArgs( static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type"); static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates"); static final ParseField K_FIELD = new ParseField("k"); - static final ParseField N_PROBE_FIELD = new ParseField("n_probe"); + // static final ParseField N_PROBE_FIELD = new ParseField("n_probe"); + static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage"); static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size"); static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor"); static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m"); @@ -97,7 +98,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD); PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD); PARSER.declareInt(Builder::setK, K_FIELD); - PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD); + // PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD); + PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD); PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD); PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD); PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD); @@ -132,7 +134,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT)); builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates); builder.field(K_FIELD.getPreferredName(), k); - builder.field(N_PROBE_FIELD.getPreferredName(), nProbes); + // builder.field(N_PROBE_FIELD.getPreferredName(), nProbes); + builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages); builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize); builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor); builder.field(HNSW_M_FIELD.getPreferredName(), hnswM); @@ -165,7 +168,7 @@ static class Builder { private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW; private int numCandidates = 1000; private int k = 10; - private int[] nProbes = new int[] { 10 }; + private double[] visitPercentages = new double[] { 1.0 }; private int ivfClusterSize = 1000; private int overSamplingFactor = 1; private int hnswM = 16; @@ -223,8 +226,8 @@ public Builder setK(int k) { return this; } - public Builder setNProbe(List nProbes) { - this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray(); + public Builder setVisitPercentages(List visitPercentages) { + this.visitPercentages = visitPercentages.stream().mapToDouble(Double::doubleValue).toArray(); return this; } @@ -330,7 +333,7 @@ public CmdLineArgs build() { indexType, numCandidates, k, - nProbes, + visitPercentages, ivfClusterSize, overSamplingFactor, hnswM, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index def4e3c14c6dc..ac4d1f948e4df 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -191,9 +191,9 @@ public static void main(String[] args) throws Exception { FormattedResults formattedResults = new FormattedResults(); for (CmdLineArgs cmdLineArgs : cmdLineArgsList) { - int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0 - ? cmdLineArgs.nProbes() - : new int[] { 0 }; + double[] visitPercentages = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0 + ? cmdLineArgs.visitPercentages() + : new double[] { 0 }; String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT); Results indexResults = new Results( cmdLineArgs.docVectors().get(0).getFileName().toString(), @@ -201,8 +201,8 @@ public static void main(String[] args) throws Exception { cmdLineArgs.numDocs(), cmdLineArgs.filterSelectivity() ); - Results[] results = new Results[nProbes.length]; - for (int i = 0; i < nProbes.length; i++) { + Results[] results = new Results[visitPercentages.length]; + for (int i = 0; i < visitPercentages.length; i++) { results[i] = new Results( cmdLineArgs.docVectors().get(0).getFileName().toString(), indexType, @@ -240,8 +240,7 @@ public static void main(String[] args) throws Exception { numSegments(indexPath, indexResults); if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) { for (int i = 0; i < results.length; i++) { - int nProbe = nProbes[i]; - KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe); + KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, visitPercentages[i]); knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination()); } } @@ -293,7 +292,7 @@ public String toString() { String[] searchHeaders = { "index_name", "index_type", - "n_probe", + "visit_percentage(%)", "latency(ms)", "net_cpu_time(ms)", "avg_cpu_count", @@ -324,7 +323,7 @@ public String toString() { queryResultsArray[i] = new String[] { queryResult.indexName, queryResult.indexType, - Integer.toString(queryResult.nProbe), + String.format(Locale.ROOT, "%.2f", queryResult.visitPercentage), String.format(Locale.ROOT, "%.2f", queryResult.avgLatency), String.format(Locale.ROOT, "%.2f", queryResult.netCpuTimeMS), String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount), @@ -400,7 +399,7 @@ static class Results { long indexTimeMS; long forceMergeTimeMS; int numSegments; - int nProbe; + double visitPercentage; double avgLatency; double qps; double avgRecall; diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index bb13dd75a4d9e..4b41a2664aa97 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -107,7 +107,7 @@ class KnnSearcher { private final float selectivity; private final int topK; private final int efSearch; - private final int nProbe; + private final double visitPercentage; private final KnnIndexTester.IndexType indexType; private int dim; private final VectorSimilarityFunction similarityFunction; @@ -116,7 +116,7 @@ class KnnSearcher { private final int searchThreads; private final int numSearchers; - KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) { + KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, double visitPercentage) { this.docPath = cmdLineArgs.docVectors(); this.indexPath = indexPath; this.queryPath = cmdLineArgs.queryVectors(); @@ -131,7 +131,7 @@ class KnnSearcher { throw new IllegalArgumentException("numQueryVectors must be > 0"); } this.efSearch = cmdLineArgs.numCandidates(); - this.nProbe = nProbe; + this.visitPercentage = visitPercentage; this.indexType = cmdLineArgs.indexType(); this.searchThreads = cmdLineArgs.searchThreads(); this.numSearchers = cmdLineArgs.numSearchers(); @@ -298,7 +298,7 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th } logger.info("checking results"); int[][] nn = getOrCalculateExactNN(offsetByteSize, filterQuery); - finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0; + finalResults.visitPercentage = indexType == KnnIndexTester.IndexType.IVF ? visitPercentage : 0; finalResults.avgRecall = checkResults(resultIds, nn, topK); finalResults.qps = (1000f * numQueryVectors) / elapsed; finalResults.avgLatency = (float) elapsed / numQueryVectors; @@ -424,7 +424,8 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery, } int efSearch = Math.max(topK, this.efSearch); if (indexType == KnnIndexTester.IndexType.IVF) { - knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe); + float visitRatio = (float) (visitPercentage / 100); + knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, visitRatio); } else { knnQuery = new ESKnnFloatVectorQuery( VECTOR_FIELD, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java index aa8921cee24c4..73cf4adb804ba 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -60,8 +60,8 @@ public class IVFVectorsFormat extends KnnVectorsFormat { ); // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. - // useful when searching with 'efSearch' type parameters instead of requiring a specific nprobe. - public static final int DYNAMIC_NPROBE = -1; + // useful when searching with 'efSearch' type parameters instead of requiring a specific ratio. + public static final float DYNAMIC_VISIT_RATIO = 0.0f; public static final int DEFAULT_VECTORS_PER_CLUSTER = 384; public static final int MIN_VECTORS_PER_CLUSTER = 64; public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536 diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index 0043f78590ac1..08bb87e5e5c12 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -35,7 +35,7 @@ import java.io.IOException; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; -import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_VISIT_RATIO; /** * Reader for IVF vectors. This reader is used to read the IVF vectors from the index. @@ -222,25 +222,28 @@ public final void search(String field, float[] target, KnnCollector knnCollector percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); } int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); - int nProbe = DYNAMIC_NPROBE; + float visitRatio = DYNAMIC_VISIT_RATIO; // Search strategy may be null if this is being called from checkIndex (e.g. from a test) if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) { - nProbe = ivfSearchStrategy.getNProbe(); + visitRatio = ivfSearchStrategy.getVisitRatio(); } FieldEntry entry = fields.get(fieldInfo.number); - if (nProbe == DYNAMIC_NPROBE) { + if (visitRatio == DYNAMIC_VISIT_RATIO) { // empirically based, and a good dynamic to get decent recall while scaling a la "efSearch" - // scaling by the number of centroids vs. the nearest neighbors requested + // scaling by the number of vectors vs. the nearest neighbors requested // not perfect, but a comparative heuristic. - // we might want to utilize the total vector count as well, but this is a good start - nProbe = (int) Math.round(Math.log10(entry.numCentroids) * Math.sqrt(knnCollector.k())); - // clip to be between 1 and the number of centroids - nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1); + // TODO: we might want to consider the density of the centroids as experiments shows that for fewer vectors per centroid, + // the least vectors we need to score to get a good recall. + float estimated = Math.round(Math.log10(numVectors) * Math.log10(numVectors) * (knnCollector.k())); + // clip so we visit at least one vector + visitRatio = estimated / numVectors; } + // we account for soar vectors here. We can potentially visit a vector twice so we multiply by 2 here. + long maxVectorVisited = (long) (2.0 * visitRatio * numVectors); CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target); PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs); - int centroidsVisited = 0; + long expectedDocs = 0; long actualDocs = 0; // initially we visit only the "centroids to search" @@ -248,8 +251,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector // TODO do we need to handle nested doc counts similarly to how we handle // filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors? while (centroidIterator.hasNext() - && (centroidsVisited < nProbe || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) { - ++centroidsVisited; + && (maxVectorVisited > actualDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) { // todo do we actually need to know the score??? long offset = centroidIterator.nextPostingListOffset(); // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 4edd6475b890d..cde64f54c80d5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1693,18 +1693,22 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map 100d) { throw new IllegalArgumentException( - "default_n_probe must be at least 1 or exactly -1, got: " + nProbe + " for field [" + fieldName + "]" + "default_visit_percentage must be between 0.0 and 100.0, got: " + + visitPercentage + + " for field [" + + fieldName + + "]" ); } } MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); - return new BBQIVFIndexOptions(clusterSize, nProbe, rescoreVector); + return new BBQIVFIndexOptions(clusterSize, visitPercentage, rescoreVector); } @Override @@ -2297,12 +2301,12 @@ public boolean validateDimension(int dim, boolean throwOnError) { static class BBQIVFIndexOptions extends QuantizedIndexOptions { final int clusterSize; - final int defaultNProbe; + final double defaultVisitPercentage; - BBQIVFIndexOptions(int clusterSize, int defaultNProbe, RescoreVector rescoreVector) { + BBQIVFIndexOptions(int clusterSize, double defaultVisitPercentage, RescoreVector rescoreVector) { super(VectorIndexType.BBQ_DISK, rescoreVector); this.clusterSize = clusterSize; - this.defaultNProbe = defaultNProbe; + this.defaultVisitPercentage = defaultVisitPercentage; } @Override @@ -2320,13 +2324,13 @@ public boolean updatableTo(DenseVectorIndexOptions update) { boolean doEquals(DenseVectorIndexOptions other) { BBQIVFIndexOptions that = (BBQIVFIndexOptions) other; return clusterSize == that.clusterSize - && defaultNProbe == that.defaultNProbe + && defaultVisitPercentage == that.defaultVisitPercentage && Objects.equals(rescoreVector, that.rescoreVector); } @Override int doHashCode() { - return Objects.hash(clusterSize, defaultNProbe, rescoreVector); + return Objects.hash(clusterSize, defaultVisitPercentage, rescoreVector); } @Override @@ -2339,7 +2343,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field("type", type); builder.field("cluster_size", clusterSize); - builder.field("default_n_probe", defaultNProbe); + builder.field("default_visit_percentage", defaultVisitPercentage); if (rescoreVector != null) { rescoreVector.toXContent(builder, params); } @@ -2736,6 +2740,7 @@ private Query createKnnFloatQuery( .add(filter, BooleanClause.Occur.FILTER) .build(); } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) { + float defaultVisitRatio = (float) (bbqIndexOptions.defaultVisitPercentage / 100d); knnQuery = parentFilter != null ? new DiversifyingChildrenIVFKnnFloatVectorQuery( name(), @@ -2744,9 +2749,9 @@ private Query createKnnFloatQuery( numCands, filter, parentFilter, - bbqIndexOptions.defaultNProbe + defaultVisitRatio ) - : new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, bbqIndexOptions.defaultNProbe); + : new IVFKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, defaultVisitRatio); } else { knnQuery = parentFilter != null ? new ESDiversifyingChildrenFloatKnnVectorQuery( diff --git a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java index 16b32c46972bc..50d94541fe666 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java @@ -11,6 +11,7 @@ import com.carrotsearch.hppc.IntHashSet; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -50,29 +51,27 @@ abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerP static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; protected final String field; - protected final int nProbe; + protected final float providedVisitRatio; protected final int k; protected final int numCands; protected final Query filter; - protected final KnnSearchStrategy searchStrategy; protected int vectorOpsCount; - protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, int numCands, Query filter) { + protected AbstractIVFKnnVectorQuery(String field, float visitRatio, int k, int numCands, Query filter) { if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); } - if (nProbe < 1 && nProbe != -1) { - throw new IllegalArgumentException("nProbe must be at least 1 or exactly -1, got: " + nProbe); + if (visitRatio < 0.0f || visitRatio > 1.0f) { + throw new IllegalArgumentException("visitRatio must be between 0.0 and 1.0 (both inclusive), got: " + visitRatio); } if (numCands < k) { throw new IllegalArgumentException("numCands must be at least k, got: " + numCands); } this.field = field; - this.nProbe = nProbe; + this.providedVisitRatio = visitRatio; this.k = k; this.filter = filter; this.numCands = numCands; - this.searchStrategy = new IVFKnnSearchStrategy(nProbe); } @Override @@ -90,12 +89,12 @@ public boolean equals(Object o) { return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter) - && Objects.equals(nProbe, that.nProbe); + && Objects.equals(providedVisitRatio, that.providedVisitRatio); } @Override public int hashCode() { - return Objects.hash(field, k, filter, nProbe); + return Objects.hash(field, k, filter, providedVisitRatio); } @Override @@ -116,16 +115,39 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { } else { filterWeight = null; } + // we request numCands as we are using it as an approximation measure // we need to ensure we are getting at least 2*k results to ensure we cover overspill duplicates - // TODO move the logic for automatically adjusting percentages/nprobe to the query, so we can only pass + // TODO move the logic for automatically adjusting percentages to the query, so we can only pass // 2k to the collector. - KnnCollectorManager knnCollectorManager = getKnnCollectorManager(Math.max(Math.round(2f * k), numCands), indexSearcher); + KnnCollectorManager knnCollectorManager = getKnnCollectorManager(Math.round(2f * k), indexSearcher); TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); List leafReaderContexts = reader.leaves(); + + assert this instanceof IVFKnnFloatVectorQuery; + int totalVectors = 0; + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + LeafReader leafReader = leafReaderContext.reader(); + FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(field); + if (floatVectorValues != null) { + totalVectors += floatVectorValues.size(); + } + } + + final float visitRatio; + if (providedVisitRatio == 0.0f) { + // dynamically set the percentage + float expected = (float) Math.round( + Math.log10(totalVectors) * Math.log10(totalVectors) * (Math.min(10_000, Math.max(numCands, 5 * k))) + ); + visitRatio = expected / totalVectors; + } else { + visitRatio = providedVisitRatio; + } + List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext context : leafReaderContexts) { - tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager)); + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio)); } TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); @@ -138,8 +160,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return new KnnScoreDocQuery(topK.scoreDocs, reader); } - private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException { - TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager); + private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager, float visitRatio) + throws IOException { + TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager, visitRatio); IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3); int deduplicateCount = 0; for (ScoreDoc scoreDoc : results.scoreDocs) { @@ -159,12 +182,13 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollec return new TopDocs(results.totalHits, deduplicatedScoreDocs); } - TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException { + TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager, float visitRatio) + throws IOException { final LeafReader reader = ctx.reader(); final Bits liveDocs = reader.getLiveDocs(); if (filterWeight == null) { - return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager); + return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager, visitRatio); } Scorer scorer = filterWeight.scorer(ctx); @@ -174,14 +198,15 @@ TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorM BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc()); final int cost = acceptDocs.cardinality(); - return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager); + return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager, visitRatio); } abstract TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, int visitedLimit, - KnnCollectorManager knnCollectorManager + KnnCollectorManager knnCollectorManager, + float visitRatio ) throws IOException; protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQuery.java index 3b665f3ccf1d3..5df47af26a0f6 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQuery.java @@ -29,7 +29,7 @@ public class DiversifyingChildrenIVFKnnFloatVectorQuery extends IVFKnnFloatVecto * @param numCands the number of nearest neighbors to gather per shard * @param childFilter the filter to apply to the results * @param parentsFilter bitset producer for the parent documents - * @param nProbe the number of probes to use for the IVF search strategy + * @param visitRatio the ratio of documents to be scored for the IVF search strategy */ public DiversifyingChildrenIVFKnnFloatVectorQuery( String field, @@ -38,9 +38,9 @@ public DiversifyingChildrenIVFKnnFloatVectorQuery( int numCands, Query childFilter, BitSetProducer parentsFilter, - int nProbe + float visitRatio ) { - super(field, query, k, numCands, childFilter, nProbe); + super(field, query, k, numCands, childFilter, visitRatio); this.parentsFilter = parentsFilter; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java index a1168f82230c4..30b37b11005b3 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.Bits; import java.io.IOException; @@ -32,10 +33,10 @@ public class IVFKnnFloatVectorQuery extends AbstractIVFKnnVectorQuery { * @param k the number of nearest neighbors to return * @param numCands the number of nearest neighbors to gather per shard * @param filter the filter to apply to the results - * @param nProbe the number of probes to use for the IVF search strategy + * @param visitRatio the ratio of vectors to score for the IVF search strategy */ - public IVFKnnFloatVectorQuery(String field, float[] query, int k, int numCands, Query filter, int nProbe) { - super(field, nProbe, k, numCands, filter); + public IVFKnnFloatVectorQuery(String field, float[] query, int k, int numCands, Query filter, float visitRatio) { + super(field, visitRatio, k, numCands, filter); this.query = query; } @@ -77,19 +78,21 @@ protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, int visitedLimit, - KnnCollectorManager knnCollectorManager + KnnCollectorManager knnCollectorManager, + float visitRatio ) throws IOException { - KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, searchStrategy, context); - if (knnCollector == null) { - return NO_RESULTS; - } LeafReader reader = context.reader(); FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field); if (floatVectorValues == null) { FloatVectorValues.checkField(reader, field); return NO_RESULTS; } - if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) { + if (floatVectorValues.size() == 0) { + return NO_RESULTS; + } + KnnSearchStrategy strategy = new IVFKnnSearchStrategy(visitRatio); + KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, strategy, context); + if (knnCollector == null) { return NO_RESULTS; } reader.searchNearestVectors(field, query, knnCollector, acceptDocs); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java index eb630ea94f44f..30fe9c5ae24a6 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java @@ -13,14 +13,14 @@ import java.util.Objects; public class IVFKnnSearchStrategy extends KnnSearchStrategy { - private final int nProbe; + private final float visitRatio; - IVFKnnSearchStrategy(int nProbe) { - this.nProbe = nProbe; + IVFKnnSearchStrategy(float visitRatio) { + this.visitRatio = visitRatio; } - public int getNProbe() { - return nProbe; + public float getVisitRatio() { + return visitRatio; } @Override @@ -28,12 +28,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; IVFKnnSearchStrategy that = (IVFKnnSearchStrategy) o; - return nProbe == that.nProbe; + return visitRatio == that.visitRatio; } @Override public int hashCode() { - return Objects.hashCode(nProbe); + return Objects.hashCode(visitRatio); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 02ef40eeda0ca..09d1ad47a1083 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -65,7 +65,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomNormalizedVector; -import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_NPROBE; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DYNAMIC_VISIT_RATIO; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.IVF_FORMAT; import static org.hamcrest.Matchers.containsString; @@ -1514,7 +1514,7 @@ public void testIVFParsing() throws IOException { .getIndexOptions(); assertEquals(3.0F, indexOptions.rescoreVector.oversample(), 0.0F); assertEquals(IVFVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER, indexOptions.clusterSize); - assertEquals(DYNAMIC_NPROBE, indexOptions.defaultNProbe); + assertEquals(DYNAMIC_VISIT_RATIO, indexOptions.defaultVisitPercentage, 0.0); } { DocumentMapper mapperService = createDocumentMapper(fieldMapping(b -> { @@ -1525,7 +1525,7 @@ public void testIVFParsing() throws IOException { b.startObject("index_options"); b.field("type", "bbq_disk"); b.field("cluster_size", 1000); - b.field("default_n_probe", 10); + b.field("default_visit_percentage", 5.0); b.field(DenseVectorFieldMapper.RescoreVector.NAME, Map.of("oversample", 2.0f)); b.endObject(); })); @@ -1536,7 +1536,7 @@ public void testIVFParsing() throws IOException { .getIndexOptions(); assertEquals(2F, indexOptions.rescoreVector.oversample(), 0.0F); assertEquals(1000, indexOptions.clusterSize); - assertEquals(10, indexOptions.defaultNProbe); + assertEquals(5.0, indexOptions.defaultVisitPercentage, 0.0); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java index e602f9098b602..71583ce813154 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java @@ -101,10 +101,10 @@ public void setUp() throws Exception { format = new IVFVectorsFormat(128, 4); } - abstract AbstractIVFKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe); + abstract AbstractIVFKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, float visitRatio); final AbstractIVFKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { - return getKnnVectorQuery(field, query, k, queryFilter, 10); + return getKnnVectorQuery(field, query, k, queryFilter, 0.05f); } final AbstractIVFKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) { @@ -275,7 +275,8 @@ public void testNonVectorField() throws IOException { /** Test bad parameters */ public void testIllegalArguments() throws IOException { expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] { 1 }, 0)); - expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] { 1 }, 1, null, 0)); + expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] { 1 }, 1, null, -1)); + expectThrows(IllegalArgumentException.class, () -> getKnnVectorQuery("xx", new float[] { 1 }, 1, null, 2)); } public void testDifferentReader() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQueryTests.java index edfe597cf961b..95581ca19653b 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/DiversifyingChildrenIVFKnnFloatVectorQueryTests.java @@ -18,7 +18,7 @@ public class DiversifyingChildrenIVFKnnFloatVectorQueryTests extends AbstractDiv @Override Query getDiversifyingChildrenKnnQuery(String fieldName, float[] queryVector, Query childFilter, int k, BitSetProducer parentBitSet) { - return new DiversifyingChildrenIVFKnnFloatVectorQuery(fieldName, queryVector, k, k, childFilter, parentBitSet, -1); + return new DiversifyingChildrenIVFKnnFloatVectorQuery(fieldName, queryVector, k, k, childFilter, parentBitSet, 0); } @Override diff --git a/server/src/test/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQueryTests.java index 2c57b6958f9ca..7de22ec3c7fa0 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQueryTests.java @@ -26,8 +26,8 @@ public class IVFKnnFloatVectorQueryTests extends AbstractIVFKnnVectorQueryTestCase { @Override - IVFKnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe) { - return new IVFKnnFloatVectorQuery(field, query, k, k, queryFilter, nProbe); + IVFKnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, float visitRatio) { + return new IVFKnnFloatVectorQuery(field, query, k, k, queryFilter, visitRatio); } @Override