Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/135342.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135342
summary: Add 'profile' support for knn query on HNSW with early termination
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
Expand Down Expand Up @@ -401,14 +398,13 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery,
topK,
efSearch,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
}
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling";
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
Expand All @@ -432,24 +428,20 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
topK,
efSearch,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
}
}
if (overSamplingFactor > 1f) {
// oversample the topK results to get more candidates for the final result
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) {
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
} else {
return docs;
}
assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling";
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
}

private static float checkResults(int[][] results, int[][] nn, int topK) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,8 @@ public void testHnswEarlyTerminationQuery() {
)
.sum();
assertTrue(
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]",
earlyTerminationVectorOpsSum < vectorOpsSum
// if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
|| (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lte vectorOps [" + vectorOpsSum + "]",
earlyTerminationVectorOpsSum <= vectorOpsSum
);
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.knn.KnnSearchStrategy;
Expand Down Expand Up @@ -2366,6 +2363,7 @@ public Query createKnnQuery(
return new MatchNoDocsQuery("No data has been indexed for field [" + name() + "]");
}
KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy();
hnswEarlyTermination &= canApplyPatienceQuery();
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(
queryVector.asByteVector(),
Expand Down Expand Up @@ -2410,6 +2408,13 @@ private boolean isQuantized() {
return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized();
}

/**
 * Reports whether the configured index options are one of the HNSW variants
 * ({@code HnswIndexOptions}, {@code Int8HnswIndexOptions}, {@code Int4HnswIndexOptions}
 * or {@code BBQHnswIndexOptions}).
 *
 * @return {@code true} if {@code indexOptions} is an instance of any of those types
 */
private boolean canApplyPatienceQuery() {
    if (indexOptions instanceof HnswIndexOptions || indexOptions instanceof Int8HnswIndexOptions) {
        return true;
    }
    return indexOptions instanceof Int4HnswIndexOptions || indexOptions instanceof BBQHnswIndexOptions;
}

private Query createKnnBitQuery(
byte[] queryVector,
int k,
Expand All @@ -2433,11 +2438,17 @@ private Query createKnnBitQuery(
.build();
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
? new ESDiversifyingChildrenByteKnnVectorQuery(
name(),
queryVector,
filter,
k,
numCands,
parentFilter,
searchStrategy,
hnswEarlyTermination
)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
Expand Down Expand Up @@ -2477,11 +2488,17 @@ private Query createKnnByteQuery(
.build();
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
? new ESDiversifyingChildrenByteKnnVectorQuery(
name(),
queryVector,
filter,
k,
numCands,
parentFilter,
searchStrategy,
hnswEarlyTermination
)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
Expand All @@ -2493,23 +2510,6 @@ private Query createKnnByteQuery(
return knnQuery;
}

/**
 * Wraps a byte or float kNN vector query in a {@link PatienceKnnVectorQuery} when
 * {@link #canApplyPatienceQuery()} allows it; any other query is returned untouched.
 *
 * @param knnQuery the query to consider for wrapping
 * @return the patience-wrapped query, or {@code knnQuery} itself when no wrapping applies
 */
private Query maybeWrapPatience(Query knnQuery) {
    if (canApplyPatienceQuery() == false) {
        return knnQuery;
    }
    if (knnQuery instanceof KnnByteVectorQuery byteQuery) {
        return PatienceKnnVectorQuery.fromByteQuery(byteQuery);
    }
    if (knnQuery instanceof KnnFloatVectorQuery floatQuery) {
        return PatienceKnnVectorQuery.fromFloatQuery(floatQuery);
    }
    return knnQuery;
}

/**
 * Reports whether the configured index options are one of the HNSW variants
 * ({@code HnswIndexOptions}, {@code Int8HnswIndexOptions}, {@code Int4HnswIndexOptions}
 * or {@code BBQHnswIndexOptions}).
 *
 * @return {@code true} if {@code indexOptions} is an instance of any of those types
 */
private boolean canApplyPatienceQuery() {
    if (indexOptions instanceof HnswIndexOptions || indexOptions instanceof Int8HnswIndexOptions) {
        return true;
    }
    return indexOptions instanceof Int4HnswIndexOptions || indexOptions instanceof BBQHnswIndexOptions;
}

private Query createKnnFloatQuery(
float[] queryVector,
int k,
Expand Down Expand Up @@ -2586,10 +2586,7 @@ private Query createKnnFloatQuery(
parentFilter,
knnSearchStrategy
)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy, hnswEarlyTermination);
}
if (rescore) {
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

public ESDiversifyingChildrenByteKnnVectorQuery(
String field,
Expand All @@ -28,9 +31,23 @@ public ESDiversifyingChildrenByteKnnVectorQuery(
int numCands,
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
) {
this(field, query, childFilter, k, numCands, parentsFilter, strategy, false);
}

/**
 * Creates the query with an explicit early-termination switch.
 *
 * @param field            passed through to the superclass constructor
 * @param query            passed through to the superclass constructor
 * @param childFilter      passed through to the superclass constructor
 * @param k                retained in {@code kParam}; not forwarded to the superclass
 * @param numCands         forwarded to the superclass constructor
 * @param parentsFilter    passed through to the superclass constructor
 * @param strategy         passed through to the superclass constructor
 * @param earlyTermination when {@code true}, the collector manager returned by
 *                         {@code getKnnCollectorManager} is wrapped with {@code PatienceCollectorManager}
 */
public ESDiversifyingChildrenByteKnnVectorQuery(
    String field,
    byte[] query,
    Query childFilter,
    int k,
    int numCands,
    BitSetProducer parentsFilter,
    KnnSearchStrategy strategy,
    boolean earlyTermination
) {
    super(field, query, childFilter, numCands, parentsFilter, strategy);
    this.earlyTermination = earlyTermination;
    this.kParam = k;
}

@Override
Expand All @@ -48,4 +65,10 @@ public void profile(QueryProfiler queryProfiler) {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

/**
 * Obtains the collector manager from the superclass and, when early termination was
 * requested at construction time, wraps it via {@code PatienceCollectorManager.wrap}.
 *
 * @param k        forwarded to the superclass and to the patience wrapper
 * @param searcher forwarded to the superclass
 * @return the superclass collector manager, wrapped only if {@code earlyTermination} is set
 */
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    final KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
    if (earlyTermination) {
        return PatienceCollectorManager.wrap(delegate, k);
    }
    return delegate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

/**
 * Convenience constructor that delegates to the seven-argument constructor with
 * early termination disabled.
 */
public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
    // Early termination is opt-in, so this overload always passes false.
    this(field, target, k, numCands, filter, strategy, false);
}

/**
 * Creates the query with an explicit early-termination switch.
 *
 * @param field            passed through to the superclass constructor
 * @param target           passed through to the superclass constructor
 * @param k                retained in {@code kParam}; not forwarded to the superclass
 * @param numCands         forwarded to the superclass constructor
 * @param filter           passed through to the superclass constructor
 * @param strategy         passed through to the superclass constructor
 * @param earlyTermination when {@code true}, the collector manager returned by
 *                         {@code getKnnCollectorManager} is wrapped with {@code PatienceCollectorManager}
 */
public ESKnnByteVectorQuery(
    String field,
    byte[] target,
    int k,
    int numCands,
    Query filter,
    KnnSearchStrategy strategy,
    boolean earlyTermination
) {
    super(field, target, numCands, filter, strategy);
    this.earlyTermination = earlyTermination;
    this.kParam = k;
}

@Override
Expand All @@ -44,4 +60,10 @@ public Integer kParam() {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

/**
 * Obtains the collector manager from the superclass and, when early termination was
 * requested at construction time, wraps it via {@code PatienceCollectorManager.wrap}.
 *
 * @param k        forwarded to the superclass and to the patience wrapper
 * @param searcher forwarded to the superclass
 * @return the superclass collector manager, wrapped only if {@code earlyTermination} is set
 */
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    final KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
    if (earlyTermination) {
        return PatienceCollectorManager.wrap(delegate, k);
    }
    return delegate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

/**
 * Convenience constructor that delegates to the seven-argument constructor with
 * early termination disabled.
 */
public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
    // Early termination is opt-in, so this overload always passes false.
    this(field, target, k, numCands, filter, strategy, false);
}

/**
 * Creates the query with an explicit early-termination switch.
 *
 * @param field            passed through to the superclass constructor
 * @param target           passed through to the superclass constructor
 * @param k                retained in {@code kParam}; not forwarded to the superclass
 * @param numCands         forwarded to the superclass constructor
 * @param filter           passed through to the superclass constructor
 * @param strategy         passed through to the superclass constructor
 * @param earlyTermination when {@code true}, the collector manager returned by
 *                         {@code getKnnCollectorManager} is wrapped with {@code PatienceCollectorManager}
 */
public ESKnnFloatVectorQuery(
    String field,
    float[] target,
    int k,
    int numCands,
    Query filter,
    KnnSearchStrategy strategy,
    boolean earlyTermination
) {
    super(field, target, numCands, filter, strategy);
    this.earlyTermination = earlyTermination;
    this.kParam = k;
}

@Override
Expand All @@ -44,4 +60,10 @@ public int kParam() {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

/**
 * Obtains the collector manager from the superclass and, when early termination was
 * requested at construction time, wraps it via {@code PatienceCollectorManager.wrap}.
 *
 * @param k        forwarded to the superclass and to the patience wrapper
 * @param searcher forwarded to the superclass
 * @return the superclass collector manager, wrapped only if {@code earlyTermination} is set
 */
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    final KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
    if (earlyTermination) {
        return PatienceCollectorManager.wrap(delegate, k);
    }
    return delegate;
}
}
Loading