diff --git a/docs/changelog/135342.yaml b/docs/changelog/135342.yaml new file mode 100644 index 0000000000000..e00d000dd303c --- /dev/null +++ b/docs/changelog/135342.yaml @@ -0,0 +1,5 @@ +pr: 135342 +summary: Add 'profile' support for knn query on HNSW with early termination +area: Vector Search +type: enhancement +issues: [] diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index c71aa21f5dd29..f614ead62cb14 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -38,9 +38,6 @@ import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.KnnByteVectorQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; -import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreDoc; @@ -401,14 +398,13 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery, topK, efSearch, filterQuery, - DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() + DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(), + indexType == KnnIndexTester.IndexType.HNSW && earlyTermination ); - if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) { - knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery); - } } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); + assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling"; QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery; queryProfilerProvider.profile(profiler); return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); @@ -432,11 +428,9 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery, topK, efSearch, filterQuery, - DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy() + DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(), + indexType == KnnIndexTester.IndexType.HNSW && earlyTermination ); - if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) { - knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery); - } } if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result @@ -444,12 +438,10 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery, } QueryProfiler profiler = new QueryProfiler(); TopDocs docs = searcher.search(knnQuery, this.topK); - if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) { - queryProfilerProvider.profile(profiler); - return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); - } else { - return docs; - } + assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling"; + QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery; + queryProfilerProvider.profile(profiler); + return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs); } private static float checkResults(int[][] results, int[][] nn, int topK) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java index ffa7727c53d7e..c5df7931c6203 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java @@ -170,10 +170,8 @@ public void testHnswEarlyTerminationQuery() { ) .sum(); assertTrue( - "earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]", - earlyTerminationVectorOpsSum < vectorOpsSum - // if both switch to brute-force due to excessive exploration, they will both equal to upperLimit - || (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1) + "earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lte vectorOps [" + vectorOpsSum + "]", + earlyTerminationVectorOpsSum <= vectorOpsSum ); } ); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index d6ed48a62b480..7b8c9934f8104 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -33,10 +33,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; -import org.apache.lucene.search.KnnByteVectorQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.knn.KnnSearchStrategy; @@ -2366,6 +2363,7 @@ public Query createKnnQuery( return new MatchNoDocsQuery("No data has been indexed for field [" + name() + "]"); } KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy(); + hnswEarlyTermination &= canApplyPatienceQuery(); return switch (getElementType()) { case BYTE -> createKnnByteQuery( queryVector.asByteVector(), @@ -2410,6 +2408,13 @@ private boolean isQuantized() { return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized(); } + private boolean canApplyPatienceQuery() { + return indexOptions instanceof HnswIndexOptions + || indexOptions instanceof Int8HnswIndexOptions + || indexOptions instanceof Int4HnswIndexOptions + || indexOptions instanceof BBQHnswIndexOptions; + } + private Query createKnnBitQuery( byte[] queryVector, int k, @@ -2433,11 +2438,17 @@ private Query createKnnBitQuery( .build(); } else { knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); - if (hnswEarlyTermination) { - knnQuery = maybeWrapPatience(knnQuery); - } + ? new ESDiversifyingChildrenByteKnnVectorQuery( + name(), + queryVector, + filter, + k, + numCands, + parentFilter, + searchStrategy, + hnswEarlyTermination + ) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination); } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( @@ -2477,11 +2488,17 @@ private Query createKnnByteQuery( .build(); } else { knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy); - if (hnswEarlyTermination) { - knnQuery = maybeWrapPatience(knnQuery); - } + ? new ESDiversifyingChildrenByteKnnVectorQuery( + name(), + queryVector, + filter, + k, + numCands, + parentFilter, + searchStrategy, + hnswEarlyTermination + ) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination); } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( @@ -2493,23 +2510,6 @@ private Query createKnnByteQuery( return knnQuery; } - private Query maybeWrapPatience(Query knnQuery) { - Query finalQuery = knnQuery; - if (knnQuery instanceof KnnByteVectorQuery knnByteVectorQuery && canApplyPatienceQuery()) { - finalQuery = PatienceKnnVectorQuery.fromByteQuery(knnByteVectorQuery); - } else if (knnQuery instanceof KnnFloatVectorQuery knnFloatVectorQuery && canApplyPatienceQuery()) { - finalQuery = PatienceKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery); - } - return finalQuery; - } - - private boolean canApplyPatienceQuery() { - return indexOptions instanceof HnswIndexOptions - || indexOptions instanceof Int8HnswIndexOptions - || indexOptions instanceof Int4HnswIndexOptions - || indexOptions instanceof BBQHnswIndexOptions; - } - private Query createKnnFloatQuery( float[] queryVector, int k, @@ -2586,10 +2586,7 @@ private Query createKnnFloatQuery( parentFilter, knnSearchStrategy ) - : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy); - if (hnswEarlyTermination) { - knnQuery = maybeWrapPatience(knnQuery); - } + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy, hnswEarlyTermination); } if (rescore) { knnQuery = RescoreKnnVectorQuery.fromInnerQuery( diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index ddc427868c672..82f9d740afd86 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -9,16 +9,19 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { private final int kParam; private long vectorOpsCount; + private final boolean earlyTermination; public ESDiversifyingChildrenByteKnnVectorQuery( String field, @@ -28,9 +31,23 @@ public ESDiversifyingChildrenByteKnnVectorQuery( int numCands, BitSetProducer parentsFilter, KnnSearchStrategy strategy + ) { + this(field, query, childFilter, k, numCands, parentsFilter, strategy, false); + } + + public ESDiversifyingChildrenByteKnnVectorQuery( + String field, + byte[] query, + Query childFilter, + int k, + int numCands, + BitSetProducer parentsFilter, + KnnSearchStrategy strategy, + boolean earlyTermination ) { super(field, query, childFilter, numCands, parentsFilter, strategy); this.kParam = k; + this.earlyTermination = earlyTermination; } @Override @@ -48,4 +65,10 @@ public void profile(QueryProfiler queryProfiler) { public KnnSearchStrategy getStrategy() { return searchStrategy; } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher); + return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 6e90da12bd7e7..4687fc8db2986 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -9,19 +9,35 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { private final int kParam; private long vectorOpsCount; + private final boolean earlyTermination; public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) { + this(field, target, k, numCands, filter, strategy, false); + } + + public ESKnnByteVectorQuery( + String field, + byte[] target, + int k, + int numCands, + Query filter, + KnnSearchStrategy strategy, + boolean earlyTermination + ) { super(field, target, numCands, filter, strategy); this.kParam = k; + this.earlyTermination = earlyTermination; } @Override @@ -44,4 +60,10 @@ public Integer kParam() { public KnnSearchStrategy getStrategy() { return searchStrategy; } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher); + return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index 04f6104476c51..30d826b05be1b 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -9,19 +9,35 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.elasticsearch.search.profile.query.QueryProfiler; public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { private final int kParam; private long vectorOpsCount; + private final boolean earlyTermination; public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) { + this(field, target, k, numCands, filter, strategy, false); + } + + public ESKnnFloatVectorQuery( + String field, + float[] target, + int k, + int numCands, + Query filter, + KnnSearchStrategy strategy, + boolean earlyTermination + ) { super(field, target, numCands, filter, strategy); this.kParam = k; + this.earlyTermination = earlyTermination; } @Override @@ -44,4 +60,10 @@ public int kParam() { public KnnSearchStrategy getStrategy() { return searchStrategy; } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher); + return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/PatienceCollectorManager.java b/server/src/main/java/org/elasticsearch/search/vectors/PatienceCollectorManager.java new file mode 100644 index 0000000000000..930046b693417 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/PatienceCollectorManager.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.HnswQueueSaturationCollector; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; + +import java.io.IOException; + +/** + * This is a decorator for the {@link KnnCollectorManager} that early terminates the wrapped {@link KnnCollector} + * based on a saturation threshold and a patience factor. It is designed + * to improve the efficiency of approximate nearest neighbor (KNN) searches by monitoring queue saturation + * during the search process. + * This applies a patience-based logic to both optimistic and regular KNN collectors. + * The saturation threshold defines the percentage of saturation at which the collector's patience is + * tested for termination. + */ +class PatienceCollectorManager implements KnnCollectorManager { + private static final double DEFAULT_SATURATION_THRESHOLD = 0.995; + + private final KnnCollectorManager knnCollectorManager; + private final int patience; + private final double saturationThreshold; + + PatienceCollectorManager(KnnCollectorManager knnCollectorManager, int patience, double saturationThreshold) { + this.knnCollectorManager = knnCollectorManager; + this.patience = patience; + this.saturationThreshold = saturationThreshold; + } + + static KnnCollectorManager wrap(KnnCollectorManager knnCollectorManager, int k) { + return new PatienceCollectorManager(knnCollectorManager, Math.max(7, (int) (k * 0.3)), DEFAULT_SATURATION_THRESHOLD); + } + + @Override + public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException { + return new HnswQueueSaturationCollector( + knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx), + saturationThreshold, + patience + ); + } + + @Override + public KnnCollector newOptimisticCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx, int k) + throws IOException { + if (knnCollectorManager.isOptimistic()) { + return new HnswQueueSaturationCollector( + knnCollectorManager.newOptimisticCollector(visitLimit, searchStrategy, ctx, k), + saturationThreshold, + patience + ); + } else { + return null; + } + } + + @Override + public boolean isOptimistic() { + return knnCollectorManager.isOptimistic(); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 43923a77b1c3e..b56b66767c7d7 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -10,7 +10,6 @@ package org.elasticsearch.index.mapper.vectors; import org.apache.lucene.search.KnnFloatVectorQuery; -import org.apache.lucene.search.PatienceKnnVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; @@ -265,9 +264,7 @@ public void testCreateNestedKnnQuery() { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { assertTrue( - query instanceof DiversifyingChildrenFloatKnnVectorQuery - || query instanceof PatienceKnnVectorQuery - || query instanceof DiversifyingChildrenIVFKnnFloatVectorQuery + query instanceof DiversifyingChildrenFloatKnnVectorQuery || query instanceof DiversifyingChildrenIVFKnnFloatVectorQuery ); } } @@ -305,7 +302,7 @@ public void testCreateNestedKnnQuery() { if (field.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { - assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery); + assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery); } vectorData = new VectorData(floatQueryVector, null); @@ -324,7 +321,7 @@ public void testCreateNestedKnnQuery() { if (field.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DiversifyingParentBlockQuery.class)); } else { - assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery); + assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery); } } } @@ -499,11 +496,7 @@ public void testCreateKnnQueryMaxDims() { if (fieldWith4096dims.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); } else { - assertTrue( - query instanceof KnnFloatVectorQuery - || query instanceof PatienceKnnVectorQuery - || query instanceof IVFKnnFloatVectorQuery - ); + assertTrue(query instanceof KnnFloatVectorQuery || query instanceof IVFKnnFloatVectorQuery); } } @@ -539,7 +532,7 @@ public void testCreateKnnQueryMaxDims() { if (fieldWith4096dims.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Bytes.class)); } else { - assertTrue(query instanceof ESKnnByteVectorQuery || query instanceof PatienceKnnVectorQuery); + assertTrue(query instanceof ESKnnByteVectorQuery); } } } @@ -650,25 +643,17 @@ public void testRescoreOversampleUsedWithoutQuantization() { if (nonQuantizedField.getIndexOptions().isFlat()) { assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class)); } else { - if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { - assertThat(patienceKnnVectorQuery.getK(), is(100)); - } else { - ESKnnByteVectorQuery knnByteVectorQuery = (ESKnnByteVectorQuery) knnQuery; - assertThat(knnByteVectorQuery.getK(), is(100)); - assertThat(knnByteVectorQuery.kParam(), is(10)); - } + ESKnnByteVectorQuery knnByteVectorQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(knnByteVectorQuery.getK(), is(100)); + assertThat(knnByteVectorQuery.kParam(), is(10)); } } else { if (nonQuantizedField.getIndexOptions().isFlat()) { assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class)); } else { - if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { - assertThat(patienceKnnVectorQuery.getK(), is(100)); - } else { - ESKnnFloatVectorQuery knnFloatVectorQuery = (ESKnnFloatVectorQuery) knnQuery; - assertThat(knnFloatVectorQuery.getK(), is(100)); - assertThat(knnFloatVectorQuery.kParam(), is(10)); - } + ESKnnFloatVectorQuery knnFloatVectorQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(knnFloatVectorQuery.getK(), is(100)); + assertThat(knnFloatVectorQuery.kParam(), is(10)); } } } @@ -722,7 +707,7 @@ public void testRescoreOversampleQueryOverrides() { if (fieldType.getIndexOptions().isFlat()) { assertThat(query, instanceOf(DenseVectorQuery.Floats.class)); } else { - assertTrue(query instanceof ESKnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery); + assertTrue(query instanceof ESKnnFloatVectorQuery); } // verify we can override a `0` to a positive number @@ -760,9 +745,9 @@ public void testRescoreOversampleQueryOverrides() { public void testFilterSearchThreshold() { List>> cases = List.of( - Tuple.tuple(FLOAT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnFloatVectorQuery) q).getStrategy()), - Tuple.tuple(BYTE, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy()), - Tuple.tuple(BIT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy()) + Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()), + Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()), + Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy()) ); for (var tuple : cases) { DenseVectorFieldType fieldType = new DenseVectorFieldType( @@ -840,13 +825,9 @@ private static void checkRescoreQueryParameters( ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; Query innerQuery = rescoreQuery.innerQuery(); - if (innerQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { - assertThat("Unexpected candidates", patienceKnnVectorQuery.getK(), equalTo(expectedCandidates)); - } else { - ESKnnFloatVectorQuery knnQuery = (ESKnnFloatVectorQuery) innerQuery; - assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); - assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates)); - assertThat("Unexpected k parameter", knnQuery.kParam(), equalTo(expectedK)); - } + ESKnnFloatVectorQuery knnQuery = (ESKnnFloatVectorQuery) innerQuery; + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates)); + assertThat("Unexpected k parameter", knnQuery.kParam(), equalTo(expectedK)); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/PatienceCollectorManagerTests.java b/server/src/test/java/org/elasticsearch/search/vectors/PatienceCollectorManagerTests.java new file mode 100644 index 0000000000000..1e48bc5654dbc --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/PatienceCollectorManagerTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.HnswQueueSaturationCollector; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.search.knn.TopKnnCollectorManager; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +public class PatienceCollectorManagerTests extends ESTestCase { + + public void testEarlyTermination() throws IOException { + int k = randomIntBetween(1, 10); + int patience = randomIntBetween(1, 2); + double saturationThreshold = randomDoubleBetween(0.01, 0.02, true); + PatienceCollectorManager patienceCollectorManager = new PatienceCollectorManager( + new TopKnnCollectorManager(k, null), + patience, + saturationThreshold + ); + HnswQueueSaturationCollector knnCollector = (HnswQueueSaturationCollector) patienceCollectorManager.newCollector( + randomIntBetween(100, 1000), + new KnnSearchStrategy.Hnsw(10), + null + ); + + for (int i = 0; i < 100; i++) { + knnCollector.collect(i, 1 - i); + if (i % 10 == 0) { + knnCollector.nextCandidate(); + } + } + assertTrue(knnCollector.earlyTerminated()); + } +}