Skip to content

Commit bc1e5c6

Browse files
committed
Minor refactoring to reuse KnnScoreDocQuery
1 parent ff2c1e9 commit bc1e5c6

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12+
import org.apache.lucene.index.IndexReader;
1213
import org.apache.lucene.index.LeafReaderContext;
1314
import org.apache.lucene.search.DocIdSetIterator;
1415
import org.apache.lucene.search.Explanation;
@@ -37,26 +38,45 @@
3738
public class KnnScoreDocQuery extends Query {
3839
private final int[] docs;
3940
private final float[] scores;
41+
42+
// the indexes in docs and scores corresponding to the first matching document in each segment.
43+
// If a segment has no matching documents, it should be assigned the index of the next segment that does.
44+
// There should be a final entry that is always docs.length-1.
4045
private final int[] segmentStarts;
46+
// an object identifying the reader context that was used to build this query
47+
4148
private final Object contextIdentity;
4249

4350
/**
4451
* Creates a query.
4552
*
4653
* @param docs the global doc IDs of documents that match, in ascending order
4754
* @param scores the scores of the matching documents
48-
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
49-
* document in each segment. If a segment has no matching documents, it should be assigned
50-
* the index of the next segment that does. There should be a final entry that is always
51-
* docs.length-1.
52-
* @param contextIdentity an object identifying the reader context that was used to build this
53-
* query
55+
* @param reader IndexReader
5456
*/
55-
KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
57+
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
5658
this.docs = docs;
5759
this.scores = scores;
58-
this.segmentStarts = segmentStarts;
59-
this.contextIdentity = contextIdentity;
60+
this.segmentStarts = findSegmentStarts(reader, docs);
61+
this.contextIdentity = reader.getContext().id();
62+
}
63+
64+
private static int[] findSegmentStarts(IndexReader reader, int[] docs) {
65+
int[] starts = new int[reader.leaves().size() + 1];
66+
starts[starts.length - 1] = docs.length;
67+
if (starts.length == 2) {
68+
return starts;
69+
}
70+
int resultIndex = 0;
71+
for (int i = 1; i < starts.length - 1; i++) {
72+
int upper = reader.leaves().get(i).docBase;
73+
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
74+
if (resultIndex < 0) {
75+
resultIndex = -1 - resultIndex;
76+
}
77+
starts[i] = resultIndex;
78+
}
79+
return starts;
6080
}
6181

6282
@Override

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12-
import org.apache.lucene.index.IndexReader;
1312
import org.apache.lucene.search.Query;
1413
import org.apache.lucene.search.ScoreDoc;
1514
import org.elasticsearch.TransportVersion;
@@ -25,7 +24,6 @@
2524
import org.elasticsearch.xcontent.XContentBuilder;
2625

2726
import java.io.IOException;
28-
import java.util.Arrays;
2927
import java.util.Objects;
3028

3129
/**
@@ -153,9 +151,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
153151
scores[i] = scoreDocs[i].score;
154152
}
155153

156-
IndexReader reader = context.getIndexReader();
157-
int[] segmentStarts = findSegmentStarts(reader, docs);
158-
return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id());
154+
return new KnnScoreDocQuery(docs, scores, context.getIndexReader());
159155
}
160156

161157
@Override
@@ -169,24 +165,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
169165
return super.doRewrite(queryRewriteContext);
170166
}
171167

172-
private static int[] findSegmentStarts(IndexReader reader, int[] docs) {
173-
int[] starts = new int[reader.leaves().size() + 1];
174-
starts[starts.length - 1] = docs.length;
175-
if (starts.length == 2) {
176-
return starts;
177-
}
178-
int resultIndex = 0;
179-
for (int i = 1; i < starts.length - 1; i++) {
180-
int upper = reader.leaves().get(i).docBase;
181-
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
182-
if (resultIndex < 0) {
183-
resultIndex = -1 - resultIndex;
184-
}
185-
starts[i] = resultIndex;
186-
}
187-
return starts;
188-
}
189-
190168
@Override
191169
protected boolean doEquals(KnnScoreDocQueryBuilder other) {
192170
if (scoreDocs.length != other.scoreDocs.length) {

0 commit comments

Comments
 (0)