Skip to content

Commit 3faf4ce

Browse files
authored
Use to VectorScorer for exact vector scoring (#109945)
Lucene 9.11 introduced a new VectorScorer interface. We should utilize this interface when scoring exact vectors. related to: #109293
1 parent c709b78 commit 3faf4ce

File tree

7 files changed

+611
-80
lines changed

7 files changed

+611
-80
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@
2828
import org.apache.lucene.index.SegmentWriteState;
2929
import org.apache.lucene.index.VectorEncoding;
3030
import org.apache.lucene.index.VectorSimilarityFunction;
31-
import org.apache.lucene.queries.function.FunctionQuery;
32-
import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource;
33-
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
34-
import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource;
35-
import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
36-
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
37-
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
38-
import org.apache.lucene.search.BooleanClause;
39-
import org.apache.lucene.search.BooleanQuery;
4031
import org.apache.lucene.search.FieldExistsQuery;
4132
import org.apache.lucene.search.Query;
4233
import org.apache.lucene.search.join.BitSetProducer;
@@ -67,6 +58,7 @@
6758
import org.elasticsearch.index.query.SearchExecutionContext;
6859
import org.elasticsearch.search.DocValueFormat;
6960
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
61+
import org.elasticsearch.search.vectors.DenseVectorQuery;
7062
import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
7163
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
7264
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
@@ -1484,19 +1476,7 @@ private Query createExactKnnByteQuery(byte[] queryVector) {
14841476
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
14851477
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
14861478
}
1487-
VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
1488-
return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
1489-
.add(
1490-
new FunctionQuery(
1491-
new ByteVectorSimilarityFunction(
1492-
vectorSimilarityFunction,
1493-
new ByteKnnVectorFieldSource(name()),
1494-
new ConstKnnByteVectorValueSource(queryVector)
1495-
)
1496-
),
1497-
BooleanClause.Occur.SHOULD
1498-
)
1499-
.build();
1479+
return new DenseVectorQuery.Bytes(queryVector, name());
15001480
}
15011481

15021482
private Query createExactKnnFloatQuery(float[] queryVector) {
@@ -1519,19 +1499,7 @@ && isNotUnitVector(squaredMagnitude)) {
15191499
}
15201500
}
15211501
}
1522-
VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
1523-
return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
1524-
.add(
1525-
new FunctionQuery(
1526-
new FloatVectorSimilarityFunction(
1527-
vectorSimilarityFunction,
1528-
new FloatKnnVectorFieldSource(name()),
1529-
new ConstKnnFloatValueSource(queryVector)
1530-
)
1531-
),
1532-
BooleanClause.Occur.SHOULD
1533-
)
1534-
.build();
1502+
return new DenseVectorQuery.Floats(queryVector, name());
15351503
}
15361504

15371505
Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter) {
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.search.vectors;
10+
11+
import org.apache.lucene.index.ByteVectorValues;
12+
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.index.LeafReaderContext;
14+
import org.apache.lucene.search.DocIdSetIterator;
15+
import org.apache.lucene.search.Explanation;
16+
import org.apache.lucene.search.IndexSearcher;
17+
import org.apache.lucene.search.Query;
18+
import org.apache.lucene.search.QueryVisitor;
19+
import org.apache.lucene.search.ScoreMode;
20+
import org.apache.lucene.search.Scorer;
21+
import org.apache.lucene.search.VectorScorer;
22+
import org.apache.lucene.search.Weight;
23+
24+
import java.io.IOException;
25+
import java.util.Arrays;
26+
import java.util.Objects;
27+
28+
/**
29+
* Exact knn query. Will iterate and score all documents that have the provided dense vector field in the index.
30+
*/
31+
public abstract class DenseVectorQuery extends Query {
32+
33+
protected final String field;
34+
35+
public DenseVectorQuery(String field) {
36+
this.field = field;
37+
}
38+
39+
@Override
40+
public void visit(QueryVisitor queryVisitor) {
41+
queryVisitor.visitLeaf(this);
42+
}
43+
44+
abstract static class DenseVectorWeight extends Weight {
45+
private final String field;
46+
private final float boost;
47+
48+
protected DenseVectorWeight(DenseVectorQuery query, float boost) {
49+
super(query);
50+
this.field = query.field;
51+
this.boost = boost;
52+
}
53+
54+
abstract VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException;
55+
56+
@Override
57+
public Explanation explain(LeafReaderContext leafReaderContext, int i) throws IOException {
58+
VectorScorer vectorScorer = vectorScorer(leafReaderContext);
59+
if (vectorScorer == null) {
60+
return Explanation.noMatch("No vector values found for field: " + field);
61+
}
62+
DocIdSetIterator iterator = vectorScorer.iterator();
63+
iterator.advance(i);
64+
if (iterator.docID() == i) {
65+
float score = vectorScorer.score();
66+
return Explanation.match(vectorScorer.score() * boost, "found vector with calculated similarity: " + score);
67+
}
68+
return Explanation.noMatch("Document not found in vector values for field: " + field);
69+
}
70+
71+
@Override
72+
public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
73+
VectorScorer vectorScorer = vectorScorer(leafReaderContext);
74+
if (vectorScorer == null) {
75+
return null;
76+
}
77+
return new DenseVectorScorer(this, vectorScorer);
78+
}
79+
80+
@Override
81+
public boolean isCacheable(LeafReaderContext leafReaderContext) {
82+
return true;
83+
}
84+
}
85+
86+
public static class Floats extends DenseVectorQuery {
87+
88+
private final float[] query;
89+
90+
public Floats(float[] query, String field) {
91+
super(field);
92+
this.query = query;
93+
}
94+
95+
public float[] getQuery() {
96+
return query;
97+
}
98+
99+
@Override
100+
public String toString(String field) {
101+
return "DenseVectorQuery.Floats";
102+
}
103+
104+
@Override
105+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
106+
return new DenseVectorWeight(Floats.this, boost) {
107+
@Override
108+
VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException {
109+
FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(field);
110+
if (vectorValues == null) {
111+
return null;
112+
}
113+
return vectorValues.scorer(query);
114+
}
115+
};
116+
}
117+
118+
@Override
119+
public boolean equals(Object o) {
120+
if (this == o) return true;
121+
if (o == null || getClass() != o.getClass()) return false;
122+
Floats floats = (Floats) o;
123+
return Objects.equals(field, floats.field) && Objects.deepEquals(query, floats.query);
124+
}
125+
126+
@Override
127+
public int hashCode() {
128+
return Objects.hash(field, Arrays.hashCode(query));
129+
}
130+
}
131+
132+
public static class Bytes extends DenseVectorQuery {
133+
134+
private final byte[] query;
135+
136+
public Bytes(byte[] query, String field) {
137+
super(field);
138+
this.query = query;
139+
}
140+
141+
@Override
142+
public String toString(String field) {
143+
return "DenseVectorQuery.Bytes";
144+
}
145+
146+
@Override
147+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
148+
return new DenseVectorWeight(Bytes.this, boost) {
149+
@Override
150+
VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException {
151+
ByteVectorValues vectorValues = leafReaderContext.reader().getByteVectorValues(field);
152+
if (vectorValues == null) {
153+
return null;
154+
}
155+
return vectorValues.scorer(query);
156+
}
157+
};
158+
}
159+
160+
@Override
161+
public boolean equals(Object o) {
162+
if (this == o) return true;
163+
if (o == null || getClass() != o.getClass()) return false;
164+
Bytes bytes = (Bytes) o;
165+
return Objects.equals(field, bytes.field) && Objects.deepEquals(query, bytes.query);
166+
}
167+
168+
@Override
169+
public int hashCode() {
170+
return Objects.hash(field, Arrays.hashCode(query));
171+
}
172+
}
173+
174+
static class DenseVectorScorer extends Scorer {
175+
176+
private final VectorScorer vectorScorer;
177+
private final DocIdSetIterator iterator;
178+
private final float boost;
179+
180+
DenseVectorScorer(DenseVectorWeight weight, VectorScorer vectorScorer) {
181+
super(weight);
182+
this.vectorScorer = vectorScorer;
183+
this.iterator = vectorScorer.iterator();
184+
this.boost = weight.boost;
185+
}
186+
187+
@Override
188+
public DocIdSetIterator iterator() {
189+
return vectorScorer.iterator();
190+
}
191+
192+
@Override
193+
public float getMaxScore(int i) throws IOException {
194+
// TODO: can we optimize this at all?
195+
return Float.POSITIVE_INFINITY;
196+
}
197+
198+
@Override
199+
public float score() throws IOException {
200+
assert iterator.docID() != -1;
201+
return vectorScorer.score() * boost;
202+
}
203+
204+
@Override
205+
public int docID() {
206+
return iterator.docID();
207+
}
208+
}
209+
}

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88

99
package org.elasticsearch.index.mapper.vectors;
1010

11-
import org.apache.lucene.queries.function.FunctionQuery;
12-
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
13-
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
14-
import org.apache.lucene.search.BooleanClause;
15-
import org.apache.lucene.search.BooleanQuery;
1611
import org.apache.lucene.search.KnnByteVectorQuery;
1712
import org.apache.lucene.search.KnnFloatVectorQuery;
1813
import org.apache.lucene.search.Query;
@@ -25,6 +20,7 @@
2520
import org.elasticsearch.index.mapper.MappedFieldType;
2621
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
2722
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
23+
import org.elasticsearch.search.vectors.DenseVectorQuery;
2824
import org.elasticsearch.search.vectors.VectorData;
2925

3026
import java.io.IOException;
@@ -218,16 +214,7 @@ public void testExactKnnQuery() {
218214
queryVector[i] = randomFloat();
219215
}
220216
Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector));
221-
assertTrue(query instanceof BooleanQuery);
222-
BooleanQuery booleanQuery = (BooleanQuery) query;
223-
boolean foundFunction = false;
224-
for (BooleanClause clause : booleanQuery) {
225-
if (clause.getQuery() instanceof FunctionQuery functionQuery) {
226-
foundFunction = true;
227-
assertTrue(functionQuery.getValueSource() instanceof FloatVectorSimilarityFunction);
228-
}
229-
}
230-
assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction);
217+
assertTrue(query instanceof DenseVectorQuery.Floats);
231218
}
232219
{
233220
DenseVectorFieldType field = new DenseVectorFieldType(
@@ -245,16 +232,7 @@ public void testExactKnnQuery() {
245232
queryVector[i] = randomByte();
246233
}
247234
Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector));
248-
assertTrue(query instanceof BooleanQuery);
249-
BooleanQuery booleanQuery = (BooleanQuery) query;
250-
boolean foundFunction = false;
251-
for (BooleanClause clause : booleanQuery) {
252-
if (clause.getQuery() instanceof FunctionQuery functionQuery) {
253-
foundFunction = true;
254-
assertTrue(functionQuery.getValueSource() instanceof ByteVectorSimilarityFunction);
255-
}
256-
}
257-
assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction);
235+
assertTrue(query instanceof DenseVectorQuery.Bytes);
258236
}
259237
}
260238

0 commit comments

Comments
 (0)