Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/135342.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135342
summary: Add 'profile' support for knn query on HNSW with early termination
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
Expand Down Expand Up @@ -401,11 +398,9 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery,
topK,
efSearch,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
}
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
Expand All @@ -432,24 +427,19 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
topK,
efSearch,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
}
}
if (overSamplingFactor > 1f) {
// oversample the topK results to get more candidates for the final result
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) {
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
} else {
return docs;
}
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
}

private static float checkResults(int[][] results, int[][] nn, int topK) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,8 @@ public void testHnswEarlyTerminationQuery() {
)
.sum();
assertTrue(
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]",
earlyTerminationVectorOpsSum < vectorOpsSum
// if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
|| (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lte vectorOps [" + vectorOpsSum + "]",
earlyTerminationVectorOpsSum <= vectorOpsSum
);
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.knn.KnnSearchStrategy;
Expand Down Expand Up @@ -2380,6 +2377,7 @@ public Query createKnnQuery(
return new MatchNoDocsQuery("No data has been indexed for field [" + name() + "]");
}
KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy();
hnswEarlyTermination &= canApplyPatienceQuery();
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(
queryVector.asByteVector(),
Expand Down Expand Up @@ -2424,6 +2422,13 @@ private boolean isQuantized() {
return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized();
}

/**
 * Reports whether this field's {@code indexOptions} is one of the HNSW option
 * types checked below (plain, int8, int4 or BBQ HNSW).
 *
 * @return {@code true} when {@code indexOptions} is an instance of any of the
 *         four HNSW option classes; {@code false} otherwise, including when
 *         {@code indexOptions} is {@code null}
 */
private boolean canApplyPatienceQuery() {
    if (indexOptions instanceof HnswIndexOptions) {
        return true;
    }
    if (indexOptions instanceof Int8HnswIndexOptions) {
        return true;
    }
    if (indexOptions instanceof Int4HnswIndexOptions) {
        return true;
    }
    return indexOptions instanceof BBQHnswIndexOptions;
}

private Query createKnnBitQuery(
byte[] queryVector,
int k,
Expand All @@ -2448,10 +2453,7 @@ private Query createKnnBitQuery(
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
Expand Down Expand Up @@ -2492,10 +2494,7 @@ private Query createKnnByteQuery(
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
Expand All @@ -2507,23 +2506,6 @@ private Query createKnnByteQuery(
return knnQuery;
}

/**
 * Wraps a Lucene byte or float kNN query in a {@link PatienceKnnVectorQuery}
 * when {@link #canApplyPatienceQuery()} allows it.
 *
 * @param knnQuery the query to wrap
 * @return the patience-wrapped query, or {@code knnQuery} itself when it is
 *         neither a {@link KnnByteVectorQuery} nor a {@link KnnFloatVectorQuery},
 *         or when the patience query cannot be applied
 */
private Query maybeWrapPatience(Query knnQuery) {
    if (canApplyPatienceQuery() == false) {
        return knnQuery;
    }
    if (knnQuery instanceof KnnByteVectorQuery byteQuery) {
        return PatienceKnnVectorQuery.fromByteQuery(byteQuery);
    }
    if (knnQuery instanceof KnnFloatVectorQuery floatQuery) {
        return PatienceKnnVectorQuery.fromFloatQuery(floatQuery);
    }
    return knnQuery;
}

/**
 * Reports whether this field's {@code indexOptions} is one of the HNSW option
 * types checked below (plain, int8, int4 or BBQ HNSW).
 *
 * @return {@code true} when {@code indexOptions} is an instance of any of the
 *         four HNSW option classes; {@code false} otherwise, including when
 *         {@code indexOptions} is {@code null}
 */
private boolean canApplyPatienceQuery() {
    if (indexOptions instanceof HnswIndexOptions) {
        return true;
    }
    if (indexOptions instanceof Int8HnswIndexOptions) {
        return true;
    }
    if (indexOptions instanceof Int4HnswIndexOptions) {
        return true;
    }
    return indexOptions instanceof BBQHnswIndexOptions;
}

private Query createKnnFloatQuery(
float[] queryVector,
int k,
Expand Down Expand Up @@ -2600,10 +2582,7 @@ private Query createKnnFloatQuery(
parentFilter,
knnSearchStrategy
)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy, hnswEarlyTermination);
}
if (rescore) {
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

/**
 * Creates a byte-vector kNN query and records whether early termination was requested.
 *
 * @param field            name of the vector field to search
 * @param target           the query vector
 * @param k                value stored as {@code kParam}
 * @param numCands         candidate count; this (not {@code k}) is what is handed to the superclass
 * @param filter           filter query passed to the superclass
 * @param strategy         search strategy passed to the superclass
 * @param earlyTermination when {@code true}, the collector manager is wrapped by
 *                         {@link PatienceCollectorManager}
 */
public ESKnnByteVectorQuery(
    String field,
    byte[] target,
    int k,
    int numCands,
    Query filter,
    KnnSearchStrategy strategy,
    boolean earlyTermination
) {
    // The superclass receives numCands in its "k" slot; the caller's k is kept separately.
    super(field, target, numCands, filter, strategy);
    this.earlyTermination = earlyTermination;
    this.kParam = k;
}

/**
 * Convenience overload that delegates to the full constructor with early termination disabled.
 */
public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
    this(field, target, k, numCands, filter, strategy, false);
}

@Override
Expand All @@ -44,4 +60,10 @@ public Integer kParam() {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

/**
 * Returns the superclass's collector manager, wrapped by
 * {@link PatienceCollectorManager#wrap} when early termination was requested
 * at construction time.
 *
 * @param k        forwarded to the superclass and to the patience wrapper
 * @param searcher forwarded to the superclass
 */
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    final KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
    if (earlyTermination == false) {
        return delegate;
    }
    return PatienceCollectorManager.wrap(delegate, k);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
this(field, target, k, numCands, filter, strategy, false);
}

/**
 * Creates a float-vector kNN query and records whether early termination was requested.
 *
 * @param field            name of the vector field to search
 * @param target           the query vector
 * @param k                value stored as {@code kParam}
 * @param numCands         candidate count; this (not {@code k}) is what is handed to the superclass
 * @param filter           filter query passed to the superclass
 * @param strategy         search strategy passed to the superclass
 * @param earlyTermination when {@code true}, the collector manager is wrapped by
 *                         {@link PatienceCollectorManager}
 */
public ESKnnFloatVectorQuery(
String field,
float[] target,
int k,
int numCands,
Query filter,
KnnSearchStrategy strategy,
boolean earlyTermination
) {
// The superclass receives numCands in its "k" slot; the caller's k is kept separately.
super(field, target, numCands, filter, strategy);
this.kParam = k;
this.earlyTermination = earlyTermination;
}

@Override
Expand All @@ -44,4 +60,10 @@ public int kParam() {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

/**
 * Returns the superclass's collector manager, wrapped by
 * {@link PatienceCollectorManager#wrap} when early termination was requested
 * at construction time.
 *
 * @param k        forwarded to the superclass and to the patience wrapper
 * @param searcher forwarded to the superclass
 */
@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
    final KnnCollectorManager delegate = super.getKnnCollectorManager(k, searcher);
    if (earlyTermination == false) {
        return delegate;
    }
    return PatienceCollectorManager.wrap(delegate, k);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.HnswQueueSaturationCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;

import java.io.IOException;

/**
 * A {@link KnnCollectorManager} decorator that wraps every collector produced by
 * another manager in an {@link HnswQueueSaturationCollector}, configured with a
 * fixed saturation threshold and a patience value.
 *
 * <p>NOTE(review): the precise stopping rule lives in Lucene's
 * {@code HnswQueueSaturationCollector}; this class only supplies its two tuning
 * parameters and delegates everything else.
 */
public class PatienceCollectorManager implements KnnCollectorManager {
    /** Saturation threshold used by {@link #wrap}. */
    private static final double DEFAULT_SATURATION_THRESHOLD = 0.995;
    /** Lower bound on the patience chosen by {@link #wrap}. */
    private static final int MIN_PATIENCE = 7;
    /** Fraction of {@code k} used as the patience by {@link #wrap} when it exceeds {@link #MIN_PATIENCE}. */
    private static final double PATIENCE_FRACTION_OF_K = 0.3;

    private final KnnCollectorManager delegate;
    private final int patience;
    private final double saturationThreshold;

    /**
     * @param knnCollectorManager manager whose collectors will be wrapped
     * @param patience            patience value handed to each {@link HnswQueueSaturationCollector}
     * @param saturationThreshold saturation threshold handed to each {@link HnswQueueSaturationCollector}
     */
    PatienceCollectorManager(KnnCollectorManager knnCollectorManager, int patience, double saturationThreshold) {
        this.delegate = knnCollectorManager;
        this.patience = patience;
        this.saturationThreshold = saturationThreshold;
    }

    /**
     * Wraps {@code knnCollectorManager} using the default saturation threshold and a
     * patience of {@code max(7, (int) (k * 0.3))}.
     *
     * @param knnCollectorManager manager to decorate
     * @param k                   value from which the patience is derived
     * @return a manager whose collectors are saturation-aware
     */
    public static KnnCollectorManager wrap(KnnCollectorManager knnCollectorManager, int k) {
        final int scaledPatience = (int) (k * PATIENCE_FRACTION_OF_K);
        final int effectivePatience = Math.max(MIN_PATIENCE, scaledPatience);
        return new PatienceCollectorManager(knnCollectorManager, effectivePatience, DEFAULT_SATURATION_THRESHOLD);
    }

    /** Creates the delegate's regular collector and wraps it. */
    @Override
    public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException {
        return saturating(delegate.newCollector(visitLimit, searchStrategy, ctx));
    }

    /**
     * Creates the delegate's optimistic collector and wraps it.
     *
     * @return the wrapped collector, or {@code null} when the delegate is not optimistic
     */
    @Override
    public KnnCollector newOptimisticCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx, int k)
        throws IOException {
        if (delegate.isOptimistic() == false) {
            return null;
        }
        return saturating(delegate.newOptimisticCollector(visitLimit, searchStrategy, ctx, k));
    }

    /** Mirrors the delegate's answer. */
    @Override
    public boolean isOptimistic() {
        return delegate.isOptimistic();
    }

    /** Applies this manager's threshold and patience to {@code inner}. */
    private KnnCollector saturating(KnnCollector inner) {
        return new HnswQueueSaturationCollector(inner, saturationThreshold, patience);
    }
}
Loading