Skip to content

Commit b67637a

Browse files
committed
Change abstraction to wrap around KNN query
1 parent 8d88cab commit b67637a

File tree

5 files changed

+122
-149
lines changed

5 files changed

+122
-149
lines changed

lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
9999
if (topK.scoreDocs.length == 0) {
100100
return new MatchNoDocsQuery();
101101
}
102-
return createRewrittenQuery(reader, topK);
102+
return createRewrittenQuery(reader, topK.scoreDocs);
103103
}
104104

105105
private TopDocs searchLeaf(
@@ -255,18 +255,18 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
255255
return TopDocs.merge(k, perLeafResults);
256256
}
257257

258-
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
259-
int len = topK.scoreDocs.length;
258+
static Query createRewrittenQuery(IndexReader reader, ScoreDoc[] scoreDocs) {
259+
int len = scoreDocs.length;
260260

261261
assert len > 0;
262-
float maxScore = topK.scoreDocs[0].score;
262+
float maxScore = scoreDocs[0].score;
263263

264-
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
264+
Arrays.sort(scoreDocs, Comparator.comparingInt(a -> a.doc));
265265
int[] docs = new int[len];
266266
float[] scores = new float[len];
267267
for (int i = 0; i < len; i++) {
268-
docs[i] = topK.scoreDocs[i].doc;
269-
scores[i] = topK.scoreDocs[i].score;
268+
docs[i] = scoreDocs[i].doc;
269+
scores[i] = scoreDocs[i].score;
270270
}
271271
int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
272272
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.search;
18+
19+
import static org.apache.lucene.search.AbstractKnnVectorQuery.createRewrittenQuery;
20+
21+
import java.io.IOException;
22+
import java.util.Arrays;
23+
import java.util.Objects;
24+
import org.apache.lucene.index.FieldInfo;
25+
import org.apache.lucene.index.FloatVectorValues;
26+
import org.apache.lucene.index.IndexReader;
27+
import org.apache.lucene.index.VectorSimilarityFunction;
28+
29+
/**
30+
* A wrapper of KnnFloatVectorQuery which does full-precision reranking.
31+
*
32+
* @lucene.experimental
33+
*/
34+
public class RerankKnnFloatVectorQuery extends Query {
35+
36+
private final int k;
37+
private final float[] target;
38+
private final KnnFloatVectorQuery query;
39+
40+
/**
41+
* Execute the KnnFloatVectorQuery and re-rank using full-precision vectors
42+
*
43+
* @param query the KNN query to execute as initial phase
44+
* @param target the target of the search
45+
* @param k the number of documents to find
46+
* @throws IllegalArgumentException if <code>k</code> is less than 1
47+
*/
48+
public RerankKnnFloatVectorQuery(KnnFloatVectorQuery query, float[] target, int k) {
49+
this.query = query;
50+
this.target = target;
51+
this.k = k;
52+
}
53+
54+
@Override
55+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
56+
IndexReader reader = indexSearcher.getIndexReader();
57+
Query rewritten = indexSearcher.rewrite(query);
58+
Weight weight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
59+
HitQueue queue = new HitQueue(k, false);
60+
for (var leaf : reader.leaves()) {
61+
Scorer scorer = weight.scorer(leaf);
62+
if (scorer == null) {
63+
continue;
64+
}
65+
FloatVectorValues floatVectorValues = leaf.reader().getFloatVectorValues(query.getField());
66+
if (floatVectorValues == null) {
67+
continue;
68+
}
69+
FieldInfo fi = leaf.reader().getFieldInfos().fieldInfo(query.getField());
70+
VectorSimilarityFunction comparer = fi.getVectorSimilarityFunction();
71+
DocIdSetIterator iterator = scorer.iterator();
72+
while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
73+
int docId = iterator.docID();
74+
float[] vectorValue = floatVectorValues.vectorValue(docId);
75+
float score = comparer.compare(vectorValue, target);
76+
queue.insertWithOverflow(new ScoreDoc(docId, score));
77+
}
78+
}
79+
int i = 0;
80+
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
81+
for (ScoreDoc topDoc : queue) {
82+
scoreDocs[i++] = topDoc;
83+
}
84+
return createRewrittenQuery(reader, scoreDocs);
85+
}
86+
87+
@Override
88+
public int hashCode() {
89+
int result = Arrays.hashCode(target);
90+
result = 31 * result + Objects.hash(query, k);
91+
return result;
92+
}
93+
94+
@Override
95+
public boolean equals(Object o) {
96+
if (this == o) return true;
97+
RerankKnnFloatVectorQuery that = (RerankKnnFloatVectorQuery) o;
98+
return Objects.equals(query, that.query) && k == that.k;
99+
}
100+
101+
@Override
102+
public void visit(QueryVisitor visitor) {
103+
query.visit(visitor);
104+
}
105+
106+
@Override
107+
public String toString(String field) {
108+
return getClass().getSimpleName() + ":" + query.toString(field) + "[" + k + "]";
109+
}
110+
}

lucene/core/src/java/org/apache/lucene/search/TwoPhaseKnnVectorQuery.java

Lines changed: 0 additions & 138 deletions
This file was deleted.

lucene/core/src/test/META-INF/services/org.apache.lucene.codecs.Codec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515

1616
org.apache.lucene.codecs.TestMinimalCodec$MinimalCodec
1717
org.apache.lucene.codecs.TestMinimalCodec$MinimalCompoundCodec
18-
org.apache.lucene.search.TestTwoPhaseKnnVectorQuery$QuantizedCodec
18+
org.apache.lucene.search.TestRerankKnnFloatVectorQuery$QuantizedCodec

lucene/core/src/test/org/apache/lucene/search/TestTwoPhaseKnnVectorQuery.java renamed to lucene/core/src/test/org/apache/lucene/search/TestRerankKnnFloatVectorQuery.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import org.junit.Before;
4040
import org.junit.Test;
4141

42-
public class TestTwoPhaseKnnVectorQuery extends LuceneTestCase {
42+
public class TestRerankKnnFloatVectorQuery extends LuceneTestCase {
4343

4444
private static final String FIELD = "vector";
4545
public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION =
@@ -85,8 +85,9 @@ public void testTwoPhaseKnnVectorQuery() throws Exception {
8585
int k = 10;
8686
double oversample = 1.0;
8787

88-
TwoPhaseKnnVectorQuery query =
89-
new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null);
88+
KnnFloatVectorQuery knnQuery =
89+
new KnnFloatVectorQuery(FIELD, targetVector, k + (int) (k * oversample));
90+
RerankKnnFloatVectorQuery query = new RerankKnnFloatVectorQuery(knnQuery, targetVector, k);
9091
TopDocs topDocs = searcher.search(query, k);
9192

9293
// Step 3: Verify that TopDocs scores match similarity with unquantized vectors

0 commit comments

Comments
 (0)