Skip to content

Commit df06716

Browse files
committed
Use a FunctionScoreQuery to replace scores using a VectorSimilarity based DoubleValueSource
1 parent 64c362b commit df06716

File tree

7 files changed

+388
-33
lines changed

7 files changed

+388
-33
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ static TransportVersion def(int id) {
190190
public static final TransportVersion LOGSDB_TELEMETRY_STATS = def(8_785_00_0);
191191
public static final TransportVersion KQL_QUERY_ADDED = def(8_786_00_0);
192192
public static final TransportVersion ROLE_MONITOR_STATS = def(8_787_00_0);
193+
public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_788_00_0);
193194

194195
/*
195196
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.lucene.index.SegmentWriteState;
3131
import org.apache.lucene.index.VectorEncoding;
3232
import org.apache.lucene.index.VectorSimilarityFunction;
33+
import org.apache.lucene.queries.function.FunctionScoreQuery;
3334
import org.apache.lucene.search.FieldExistsQuery;
3435
import org.apache.lucene.search.Query;
3536
import org.apache.lucene.search.join.BitSetProducer;
@@ -121,6 +122,8 @@ public static boolean isNotUnitVector(float magnitude) {
121122
public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions
122123
public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions
123124

125+
public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates
126+
124127
public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
125128
public static final int MAGNITUDE_BYTES = 4;
126129

@@ -2000,6 +2003,7 @@ public Query createKnnQuery(
20002003
VectorData queryVector,
20012004
Integer k,
20022005
int numCands,
2006+
Float rescoreOversample,
20032007
Query filter,
20042008
Float similarityThreshold,
20052009
BitSetProducer parentFilter
@@ -2010,21 +2014,50 @@ public Query createKnnQuery(
20102014
);
20112015
}
20122016
return switch (getElementType()) {
2013-
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
2014-
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter);
2015-
case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
2017+
case BYTE -> createKnnByteQuery(
2018+
queryVector.asByteVector(),
2019+
k,
2020+
numCands,
2021+
filter,
2022+
rescoreOversample,
2023+
similarityThreshold,
2024+
parentFilter
2025+
);
2026+
case FLOAT -> createKnnFloatQuery(
2027+
queryVector.asFloatVector(),
2028+
k,
2029+
numCands,
2030+
rescoreOversample,
2031+
filter,
2032+
similarityThreshold,
2033+
parentFilter
2034+
);
2035+
case BIT -> createKnnBitQuery(
2036+
queryVector.asByteVector(),
2037+
k,
2038+
numCands,
2039+
rescoreOversample,
2040+
filter,
2041+
similarityThreshold,
2042+
parentFilter
2043+
);
20162044
};
20172045
}
20182046

20192047
private Query createKnnBitQuery(
20202048
byte[] queryVector,
20212049
Integer k,
20222050
int numCands,
2051+
Float rescoreOversample,
20232052
Query filter,
20242053
Float similarityThreshold,
20252054
BitSetProducer parentFilter
20262055
) {
20272056
elementType.checkDimensions(dims, queryVector.length);
2057+
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
2058+
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
2059+
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
2060+
}
20282061
Query knnQuery = parentFilter != null
20292062
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
20302063
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
@@ -2035,6 +2068,17 @@ private Query createKnnBitQuery(
20352068
similarity.score(similarityThreshold, elementType, dims)
20362069
);
20372070
}
2071+
if (rescoreOversample != null) {
2072+
knnQuery = new FunctionScoreQuery(
2073+
knnQuery,
2074+
new VectorSimilarityByteValueSource(
2075+
name(),
2076+
queryVector,
2077+
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE)
2078+
)
2079+
);
2080+
2081+
}
20382082
return knnQuery;
20392083
}
20402084

@@ -2043,6 +2087,7 @@ private Query createKnnByteQuery(
20432087
Integer k,
20442088
int numCands,
20452089
Query filter,
2090+
Float rescoreOversample,
20462091
Float similarityThreshold,
20472092
BitSetProducer parentFilter
20482093
) {
@@ -2052,23 +2097,38 @@ private Query createKnnByteQuery(
20522097
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
20532098
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
20542099
}
2100+
int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2101+
int adjustedNumCands = Math.max(adjustedK, numCands);
2102+
20552103
Query knnQuery = parentFilter != null
2056-
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
2057-
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
2104+
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
2105+
: new ESKnnByteVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
20582106
if (similarityThreshold != null) {
20592107
knnQuery = new VectorSimilarityQuery(
20602108
knnQuery,
20612109
similarityThreshold,
20622110
similarity.score(similarityThreshold, elementType, dims)
20632111
);
20642112
}
2113+
if (rescoreOversample != null) {
2114+
knnQuery = new FunctionScoreQuery(
2115+
knnQuery,
2116+
new VectorSimilarityByteValueSource(
2117+
name(),
2118+
queryVector,
2119+
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE)
2120+
)
2121+
);
2122+
2123+
}
20652124
return knnQuery;
20662125
}
20672126

20682127
private Query createKnnFloatQuery(
20692128
float[] queryVector,
20702129
Integer k,
20712130
int numCands,
2131+
Float rescoreOversample,
20722132
Query filter,
20732133
Float similarityThreshold,
20742134
BitSetProducer parentFilter
@@ -2088,16 +2148,30 @@ && isNotUnitVector(squaredMagnitude)) {
20882148
}
20892149
}
20902150
}
2151+
2152+
int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2153+
int adjustedNumCands = Math.max(adjustedK, numCands);
20912154
Query knnQuery = parentFilter != null
2092-
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
2093-
: new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter);
2155+
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
2156+
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
20942157
if (similarityThreshold != null) {
20952158
knnQuery = new VectorSimilarityQuery(
20962159
knnQuery,
20972160
similarityThreshold,
20982161
similarity.score(similarityThreshold, elementType, dims)
20992162
);
21002163
}
2164+
if (rescoreOversample != null) {
2165+
knnQuery = new FunctionScoreQuery(
2166+
knnQuery,
2167+
new VectorSimilarityFloatValueSource(
2168+
name(),
2169+
queryVector,
2170+
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT)
2171+
)
2172+
);
2173+
2174+
}
21012175
return knnQuery;
21022176
}
21032177

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.mapper.vectors;
11+
12+
import org.apache.lucene.index.ByteVectorValues;
13+
import org.apache.lucene.index.KnnVectorValues;
14+
import org.apache.lucene.index.LeafReader;
15+
import org.apache.lucene.index.LeafReaderContext;
16+
import org.apache.lucene.index.VectorSimilarityFunction;
17+
import org.apache.lucene.search.DocIdSetIterator;
18+
import org.apache.lucene.search.DoubleValues;
19+
import org.apache.lucene.search.DoubleValuesSource;
20+
import org.apache.lucene.search.IndexSearcher;
21+
22+
import java.io.IOException;
23+
import java.util.Arrays;
24+
import java.util.Objects;
25+
26+
public class VectorSimilarityByteValueSource extends DoubleValuesSource {
27+
28+
private final String field;
29+
private final byte[] target;
30+
private final VectorSimilarityFunction vectorSimilarityFunction;
31+
32+
public VectorSimilarityByteValueSource(String field, byte[] target, VectorSimilarityFunction vectorSimilarityFunction) {
33+
this.field = field;
34+
this.target = target;
35+
this.vectorSimilarityFunction = vectorSimilarityFunction;
36+
}
37+
38+
@Override
39+
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
40+
final LeafReader reader = ctx.reader();
41+
42+
ByteVectorValues vectorValues = reader.getByteVectorValues(field);
43+
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
44+
45+
return new DoubleValues() {
46+
private int docId = -1;
47+
48+
@Override
49+
public double doubleValue() throws IOException {
50+
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId));
51+
}
52+
53+
@Override
54+
public boolean advanceExact(int doc) throws IOException {
55+
docId = doc;
56+
return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS;
57+
}
58+
};
59+
}
60+
61+
@Override
62+
public boolean needsScores() {
63+
return false;
64+
}
65+
66+
@Override
67+
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
68+
return this;
69+
}
70+
71+
@Override
72+
public int hashCode() {
73+
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
74+
}
75+
76+
@Override
77+
public boolean equals(Object o) {
78+
if (this == o) return true;
79+
if (o == null || getClass() != o.getClass()) return false;
80+
VectorSimilarityByteValueSource that = (VectorSimilarityByteValueSource) o;
81+
return Objects.equals(field, that.field)
82+
&& Objects.deepEquals(target, that.target)
83+
&& vectorSimilarityFunction == that.vectorSimilarityFunction;
84+
}
85+
86+
@Override
87+
public String toString() {
88+
return "VectorSimilarityByteValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")";
89+
}
90+
91+
@Override
92+
public boolean isCacheable(LeafReaderContext ctx) {
93+
return false;
94+
}
95+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.mapper.vectors;
11+
12+
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.index.KnnVectorValues;
14+
import org.apache.lucene.index.LeafReader;
15+
import org.apache.lucene.index.LeafReaderContext;
16+
import org.apache.lucene.index.VectorSimilarityFunction;
17+
import org.apache.lucene.search.DocIdSetIterator;
18+
import org.apache.lucene.search.DoubleValues;
19+
import org.apache.lucene.search.DoubleValuesSource;
20+
import org.apache.lucene.search.IndexSearcher;
21+
22+
import java.io.IOException;
23+
import java.util.Arrays;
24+
import java.util.Objects;
25+
26+
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
27+
28+
private final String field;
29+
private final float[] target;
30+
private final VectorSimilarityFunction vectorSimilarityFunction;
31+
32+
public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
33+
this.field = field;
34+
this.target = target;
35+
this.vectorSimilarityFunction = vectorSimilarityFunction;
36+
}
37+
38+
@Override
39+
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
40+
final LeafReader reader = ctx.reader();
41+
42+
FloatVectorValues vectorValues = reader.getFloatVectorValues(field);
43+
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
44+
45+
return new DoubleValues() {
46+
private int docId = -1;
47+
48+
@Override
49+
public double doubleValue() throws IOException {
50+
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId));
51+
}
52+
53+
@Override
54+
public boolean advanceExact(int doc) throws IOException {
55+
docId = doc;
56+
return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS;
57+
}
58+
};
59+
}
60+
61+
@Override
62+
public boolean needsScores() {
63+
return false;
64+
}
65+
66+
@Override
67+
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
68+
return this;
69+
}
70+
71+
@Override
72+
public int hashCode() {
73+
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
74+
}
75+
76+
@Override
77+
public boolean equals(Object o) {
78+
if (this == o) return true;
79+
if (o == null || getClass() != o.getClass()) return false;
80+
VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o;
81+
return Objects.equals(field, that.field)
82+
&& Objects.deepEquals(target, that.target)
83+
&& vectorSimilarityFunction == that.vectorSimilarityFunction;
84+
}
85+
86+
@Override
87+
public String toString() {
88+
return "VectorSimilarityFloatValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")";
89+
}
90+
91+
@Override
92+
public boolean isCacheable(LeafReaderContext ctx) {
93+
return false;
94+
}
95+
}

0 commit comments

Comments
 (0)