Skip to content

Commit ab36139

Browse files
committed
Compute visitRatio globally (from the total vector count across all segments) when it is determined dynamically
1 parent d0019c2 commit ab36139

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import com.carrotsearch.hppc.IntHashSet;
1313

14+
import org.apache.lucene.index.FloatVectorValues;
1415
import org.apache.lucene.index.IndexReader;
1516
import org.apache.lucene.index.LeafReader;
1617
import org.apache.lucene.index.LeafReaderContext;
@@ -50,11 +51,10 @@ abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerP
5051
static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
5152

5253
protected final String field;
53-
protected final float visitRatio;
54+
protected final float providedVisitRatio;
5455
protected final int k;
5556
protected final int numCands;
5657
protected final Query filter;
57-
protected final IVFKnnSearchStrategy searchStrategy;
5858
protected int vectorOpsCount;
5959

6060
protected AbstractIVFKnnVectorQuery(String field, float visitRatio, int k, int numCands, Query filter) {
@@ -68,11 +68,10 @@ protected AbstractIVFKnnVectorQuery(String field, float visitRatio, int k, int n
6868
throw new IllegalArgumentException("numCands must be at least k, got: " + numCands);
6969
}
7070
this.field = field;
71-
this.visitRatio = visitRatio;
71+
this.providedVisitRatio = visitRatio;
7272
this.k = k;
7373
this.filter = filter;
7474
this.numCands = numCands;
75-
this.searchStrategy = new IVFKnnSearchStrategy(visitRatio);
7675
}
7776

7877
@Override
@@ -90,12 +89,12 @@ public boolean equals(Object o) {
9089
return k == that.k
9190
&& Objects.equals(field, that.field)
9291
&& Objects.equals(filter, that.filter)
93-
&& Objects.equals(visitRatio, that.visitRatio);
92+
&& Objects.equals(providedVisitRatio, that.providedVisitRatio);
9493
}
9594

9695
@Override
9796
public int hashCode() {
98-
return Objects.hash(field, k, filter, visitRatio);
97+
return Objects.hash(field, k, filter, providedVisitRatio);
9998
}
10099

101100
@Override
@@ -116,16 +115,36 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
116115
} else {
117116
filterWeight = null;
118117
}
118+
119119
// we request numCands as we are using it as an approximation measure
120120
// we need to ensure we are getting at least 2*k results to ensure we cover overspill duplicates
121121
// TODO move the logic for automatically adjusting percentages to the query, so we can only pass
122122
// 2k to the collector.
123123
KnnCollectorManager knnCollectorManager = getKnnCollectorManager(Math.max(Math.round(2f * k), numCands), indexSearcher);
124124
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
125125
List<LeafReaderContext> leafReaderContexts = reader.leaves();
126+
127+
int totalVectors = 0;
128+
for (LeafReaderContext leafReaderContext : leafReaderContexts) {
129+
LeafReader leafReader = leafReaderContext.reader();
130+
FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(field);
131+
if (floatVectorValues != null) {
132+
totalVectors += floatVectorValues.size();
133+
}
134+
}
135+
136+
final float visitRatio;
137+
if (providedVisitRatio == 0.0f) {
138+
// dynamically set the percentage
139+
float expected = (float) Math.round(1.75f * Math.log10(numCands) * Math.log10(numCands) * (numCands));
140+
visitRatio = expected / totalVectors;
141+
} else {
142+
visitRatio = providedVisitRatio;
143+
}
144+
126145
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
127146
for (LeafReaderContext context : leafReaderContexts) {
128-
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
147+
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, visitRatio));
129148
}
130149
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
131150

@@ -138,8 +157,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
138157
return new KnnScoreDocQuery(topK.scoreDocs, reader);
139158
}
140159

141-
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
142-
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
160+
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager, float visitRatio)
161+
throws IOException {
162+
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager, visitRatio);
143163
IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3);
144164
int deduplicateCount = 0;
145165
for (ScoreDoc scoreDoc : results.scoreDocs) {
@@ -159,12 +179,13 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollec
159179
return new TopDocs(results.totalHits, deduplicatedScoreDocs);
160180
}
161181

162-
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
182+
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager, float visitRatio)
183+
throws IOException {
163184
final LeafReader reader = ctx.reader();
164185
final Bits liveDocs = reader.getLiveDocs();
165186

166187
if (filterWeight == null) {
167-
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
188+
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager, visitRatio);
168189
}
169190

170191
Scorer scorer = filterWeight.scorer(ctx);
@@ -174,14 +195,15 @@ TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorM
174195

175196
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
176197
final int cost = acceptDocs.cardinality();
177-
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
198+
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager, visitRatio);
178199
}
179200

180201
abstract TopDocs approximateSearch(
181202
LeafReaderContext context,
182203
Bits acceptDocs,
183204
int visitedLimit,
184-
KnnCollectorManager knnCollectorManager
205+
KnnCollectorManager knnCollectorManager,
206+
float visitRatio
185207
) throws IOException;
186208

187209
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {

server/src/main/java/org/elasticsearch/search/vectors/IVFKnnFloatVectorQuery.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ protected TopDocs approximateSearch(
7878
LeafReaderContext context,
7979
Bits acceptDocs,
8080
int visitedLimit,
81-
KnnCollectorManager knnCollectorManager
81+
KnnCollectorManager knnCollectorManager,
82+
float visitRatio
8283
) throws IOException {
8384
LeafReader reader = context.reader();
8485
FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field);
@@ -89,13 +90,7 @@ protected TopDocs approximateSearch(
8990
if (floatVectorValues.size() == 0) {
9091
return NO_RESULTS;
9192
}
92-
KnnSearchStrategy strategy = searchStrategy;
93-
if (searchStrategy.getVisitRatio() == 0.0f) {
94-
// dynamically set the percentage
95-
float expected = (float) Math.round(1.75f * Math.log10(numCands) * Math.log10(numCands) * (numCands));
96-
float ratio = expected / floatVectorValues.size();
97-
strategy = new IVFKnnSearchStrategy(ratio);
98-
}
93+
KnnSearchStrategy strategy = new IVFKnnSearchStrategy(visitRatio);
9994
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, strategy, context);
10095
if (knnCollector == null) {
10196
return NO_RESULTS;

0 commit comments

Comments
 (0)