Skip to content

Commit 4497b92

Browse files
committed
Add vector rescore builder to kNN retriever
1 parent 2a9e300 commit 4497b92

File tree

7 files changed

+50
-19
lines changed

7 files changed

+50
-19
lines changed

server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
2121
import org.elasticsearch.search.vectors.KnnSearchBuilder;
2222
import org.elasticsearch.search.vectors.QueryVectorBuilder;
23+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
2324
import org.elasticsearch.search.vectors.VectorData;
2425
import org.elasticsearch.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.xcontent.ObjectParser;
2527
import org.elasticsearch.xcontent.ParseField;
2628
import org.elasticsearch.xcontent.XContentBuilder;
2729
import org.elasticsearch.xcontent.XContentParser;
@@ -52,6 +54,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
5254
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
5355
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
5456
public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity");
57+
public static final ParseField RESCORE_FIELD = new ParseField("rescore");
5558

5659
@SuppressWarnings("unchecked")
5760
public static final ConstructingObjectParser<KnnRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
@@ -73,7 +76,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
7376
(QueryVectorBuilder) args[2],
7477
(int) args[3],
7578
(int) args[4],
76-
(Float) args[5]
79+
(RescoreVectorBuilder) args[6], (Float) args[5]
7780
);
7881
}
7982
);
@@ -89,6 +92,12 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
8992
PARSER.declareInt(constructorArg(), K_FIELD);
9093
PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD);
9194
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY);
95+
PARSER.declareField(
96+
optionalConstructorArg(),
97+
(p, c) -> RescoreVectorBuilder.fromXContent(p),
98+
RESCORE_FIELD,
99+
ObjectParser.ValueType.OBJECT_OR_NULL
100+
);
92101
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
93102
}
94103

@@ -104,6 +113,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
104113
private final QueryVectorBuilder queryVectorBuilder;
105114
private final int k;
106115
private final int numCands;
116+
private final RescoreVectorBuilder rescoreVectorBuilder;
107117
private final Float similarity;
108118

109119
public KnnRetrieverBuilder(
@@ -112,6 +122,7 @@ public KnnRetrieverBuilder(
112122
QueryVectorBuilder queryVectorBuilder,
113123
int k,
114124
int numCands,
125+
RescoreVectorBuilder rescoreVectorBuilder,
115126
Float similarity
116127
) {
117128
if (queryVector == null && queryVectorBuilder == null) {
@@ -137,6 +148,7 @@ public KnnRetrieverBuilder(
137148
this.k = k;
138149
this.numCands = numCands;
139150
this.similarity = similarity;
151+
this.rescoreVectorBuilder = rescoreVectorBuilder;
140152
}
141153

142154
private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
@@ -148,6 +160,7 @@ private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVe
148160
this.similarity = clone.similarity;
149161
this.retrieverName = clone.retrieverName;
150162
this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
163+
this.rescoreVectorBuilder = clone.rescoreVectorBuilder;
151164
}
152165

153166
@Override
@@ -229,6 +242,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
229242
null,
230243
k,
231244
numCands,
245+
rescoreVectorBuilder,
232246
similarity
233247
);
234248
if (preFilterQueryBuilders != null) {
@@ -261,6 +275,10 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
261275
if (similarity != null) {
262276
builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity);
263277
}
278+
279+
if (rescoreVectorBuilder != null) {
280+
builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder);
281+
}
264282
}
265283

266284
@Override
@@ -272,12 +290,13 @@ public boolean doEquals(Object o) {
272290
&& ((queryVector == null && that.queryVector == null)
273291
|| (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get())))
274292
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
275-
&& Objects.equals(similarity, that.similarity);
293+
&& Objects.equals(similarity, that.similarity)
294+
&& Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder);
276295
}
277296

278297
@Override
279298
public int doHashCode() {
280-
int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
299+
int result = Objects.hash(field, queryVectorBuilder, k, numCands, rescoreVectorBuilder, similarity);
281300
result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null);
282301
return result;
283302
}

server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.search.SearchModule;
2323
import org.elasticsearch.search.builder.SearchSourceBuilder;
2424
import org.elasticsearch.search.rank.RankDoc;
25+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
2526
import org.elasticsearch.test.AbstractXContentTestCase;
2627
import org.elasticsearch.usage.SearchUsage;
2728
import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -51,8 +52,18 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() {
5152
int k = randomIntBetween(1, 100);
5253
int numCands = randomIntBetween(k + 20, 1000);
5354
Float similarity = randomBoolean() ? null : randomFloat();
54-
55-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(field, vector, null, k, numCands, similarity);
55+
RescoreVectorBuilder rescoreVectorBuilder = randomBoolean()
56+
? null
57+
: new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false));
58+
59+
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(
60+
field,
61+
vector,
62+
null,
63+
k,
64+
numCands,
65+
rescoreVectorBuilder, similarity
66+
);
5667

5768
List<QueryBuilder> preFilterQueryBuilders = new ArrayList<>();
5869

server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
1919
import org.elasticsearch.search.builder.SearchSourceBuilder;
2020
import org.elasticsearch.search.rank.RankDoc;
21+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
2122
import org.elasticsearch.test.ESTestCase;
2223

2324
import java.io.IOException;
@@ -69,7 +70,7 @@ private List<RetrieverBuilder> innerRetrievers(QueryRewriteContext queryRewriteC
6970
null,
7071
randomInt(10),
7172
randomIntBetween(10, 100),
72-
randomFloat()
73+
randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat()
7374
);
7475
if (randomBoolean()) {
7576
knnRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public void testTelemetryForRRFRetriever() throws IOException {
102102

103103
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
104104
{
105-
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
105+
// Plain kNN retriever: the rescore vector builder and similarity are both left unset (null).
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)));
106106
}
107107

108108
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ public void testRRFPagination() {
183183
);
184184
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
185185
// this one retrieves docs 2, 3, 6, and 7
186-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
186+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
187187
source.retriever(
188188
new RRFRetrieverBuilder(
189189
Arrays.asList(
@@ -233,7 +233,7 @@ public void testRRFWithAggs() {
233233
);
234234
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
235235
// this one retrieves docs 2, 3, 6, and 7
236-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
236+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
237237
source.retriever(
238238
new RRFRetrieverBuilder(
239239
Arrays.asList(
@@ -288,7 +288,7 @@ public void testRRFWithCollapse() {
288288
);
289289
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
290290
// this one retrieves docs 2, 3, 6, and 7
291-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
291+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
292292
source.retriever(
293293
new RRFRetrieverBuilder(
294294
Arrays.asList(
@@ -345,7 +345,7 @@ public void testRRFRetrieverWithCollapseAndAggs() {
345345
);
346346
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
347347
// this one retrieves docs 2, 3, 6, and 7
348-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
348+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
349349
source.retriever(
350350
new RRFRetrieverBuilder(
351351
Arrays.asList(
@@ -411,7 +411,7 @@ public void testMultipleRRFRetrievers() {
411411
);
412412
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
413413
// this one retrieves docs 2, 3, 6, and 7
414-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
414+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
415415
source.retriever(
416416
new RRFRetrieverBuilder(
417417
Arrays.asList(
@@ -430,7 +430,7 @@ public void testMultipleRRFRetrievers() {
430430
),
431431
// this one bring just doc 7 which should be ranked first eventually
432432
new CompoundRetrieverBuilder.RetrieverSource(
433-
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null),
433+
// The rescore vector builder and similarity are both left unset (null) for this nested retriever.
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null),
434434
null
435435
)
436436
),
@@ -477,7 +477,7 @@ public void testRRFExplainWithNamedRetrievers() {
477477
);
478478
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
479479
// this one retrieves docs 2, 3, 6, and 7
480-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
480+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
481481
source.retriever(
482482
new RRFRetrieverBuilder(
483483
Arrays.asList(
@@ -536,7 +536,7 @@ public void testRRFExplainWithAnotherNestedRRF() {
536536
);
537537
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
538538
// this one retrieves docs 2, 3, 6, and 7
539-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null);
539+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
540540

541541
RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder(
542542
Arrays.asList(
@@ -773,7 +773,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
773773
throw new IllegalStateException("Should not be called");
774774
}
775775
};
776-
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null);
776+
// Query vector comes from vectorBuilder; the rescore vector builder and similarity are both left unset (null).
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null);
777777
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
778778
var rrf = new RRFRetrieverBuilder(
779779
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ public void testRRFRetrieverWithNestedQuery() {
149149
);
150150
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
151151
// this one retrieves docs 6
152-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null);
152+
// The rescore vector builder and similarity are both left unset (null) for this retriever.
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null, null);
153153
source.retriever(
154154
new RRFRetrieverBuilder(
155155
Arrays.asList(

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public void testTelemetryForRRFRetriever() throws IOException {
103103

104104
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
105105
{
106-
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
106+
// Plain kNN retriever: the rescore vector builder and similarity are both left unset (null).
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)));
107107
}
108108

109109
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
@@ -136,7 +136,7 @@ public void testTelemetryForRRFRetriever() throws IOException {
136136
new RRFRetrieverBuilder(
137137
Arrays.asList(
138138
new CompoundRetrieverBuilder.RetrieverSource(
139-
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null),
139+
// The rescore vector builder and similarity are both left unset (null) for this nested retriever.
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null),
140140
null
141141
),
142142
new CompoundRetrieverBuilder.RetrieverSource(

0 commit comments

Comments
 (0)