Skip to content

Commit 2a9e300

Browse files
committed
Add rescore vector builder to KnnSearchBuilder
1 parent b44ec48 commit 2a9e300

File tree

14 files changed

+249
-68
lines changed

14 files changed

+249
-68
lines changed

server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public void testSimpleNested() throws Exception {
6969

7070
assertResponse(
7171
prepareSearch("test").setKnnSearch(
72-
List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null).innerHit(new InnerHitBuilder()))
72+
List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null, null).innerHit(new InnerHitBuilder()))
7373
).setAllowPartialSearchResults(false),
7474
response -> assertThat(response.getHits().getHits().length, greaterThan(0))
7575
);

server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.search.profile.query.CollectorResult;
2020
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
2121
import org.elasticsearch.search.vectors.KnnSearchBuilder;
22+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
2223
import org.elasticsearch.test.ESIntegTestCase;
2324
import org.elasticsearch.xcontent.XContentFactory;
2425

@@ -71,6 +72,7 @@ public void testProfileDfs() throws Exception {
7172
new float[] { randomFloat(), randomFloat(), randomFloat() },
7273
randomIntBetween(5, 10),
7374
50,
75+
randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)),
7476
randomBoolean() ? null : randomFloat()
7577
);
7678
if (randomBoolean()) {

server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ public void testTelemetryForRetrievers() throws IOException {
8484

8585
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
8686
{
87-
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
87+
performSearch(
88+
new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null))
89+
);
8890
}
8991

9092
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
@@ -112,7 +114,9 @@ public void testTelemetryForRetrievers() throws IOException {
112114
// search#5 -
113115
// this will record 1 entry for "knn" in `sections`
114116
{
115-
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
117+
performSearch(
118+
new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)))
119+
);
116120
}
117121

118122
// search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
5656
public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD;
5757
public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;
5858
public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits");
59+
public static final ParseField RESCORE_FIELD = new ParseField("rescore");
5960

6061
@SuppressWarnings("unchecked")
6162
private static final ConstructingObjectParser<KnnSearchBuilder.Builder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
@@ -65,7 +66,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
6566
.queryVectorBuilder((QueryVectorBuilder) args[4])
6667
.k((Integer) args[2])
6768
.numCandidates((Integer) args[3])
68-
.similarity((Float) args[5]);
69+
.similarity((Float) args[5])
70+
.rescoreVectorBuilder((RescoreVectorBuilder) args[6]);
6971
});
7072

7173
static {
@@ -78,13 +80,18 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
7880
);
7981
PARSER.declareInt(optionalConstructorArg(), K_FIELD);
8082
PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
81-
8283
PARSER.declareNamedObject(
8384
optionalConstructorArg(),
8485
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
8586
QUERY_VECTOR_BUILDER_FIELD
8687
);
8788
PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY);
89+
PARSER.declareField(
90+
optionalConstructorArg(),
91+
(p, c) -> RescoreVectorBuilder.fromXContent(p),
92+
RESCORE_FIELD,
93+
ObjectParser.ValueType.OBJECT_OR_NULL
94+
);
8895
PARSER.declareFieldArray(
8996
KnnSearchBuilder.Builder::addFilterQueries,
9097
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
@@ -116,6 +123,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw
116123
String queryName;
117124
float boost = DEFAULT_BOOST;
118125
InnerHitBuilder innerHitBuilder;
126+
final RescoreVectorBuilder rescoreVectorBuilder;
119127

120128
/**
121129
* Defines a kNN search.
@@ -124,14 +132,23 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw
124132
* @param queryVector the query vector
125133
* @param k the final number of nearest neighbors to return as top hits
126134
* @param numCands the number of nearest neighbor candidates to consider per shard
135+
* @param rescoreVectorBuilder rescore vector information
127136
*/
128-
public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) {
137+
public KnnSearchBuilder(
138+
String field,
139+
float[] queryVector,
140+
int k,
141+
int numCands,
142+
RescoreVectorBuilder rescoreVectorBuilder,
143+
Float similarity
144+
) {
129145
this(
130146
field,
131147
Objects.requireNonNull(VectorData.fromFloats(queryVector), format("[%s] cannot be null", QUERY_VECTOR_FIELD)),
132148
null,
133149
k,
134150
numCands,
151+
rescoreVectorBuilder,
135152
similarity
136153
);
137154
}
@@ -144,8 +161,15 @@ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands,
144161
* @param k the final number of nearest neighbors to return as top hits
145162
* @param numCands the number of nearest neighbor candidates to consider per shard
146163
*/
147-
public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, Float similarity) {
148-
this(field, queryVector, null, k, numCands, similarity);
164+
public KnnSearchBuilder(
165+
String field,
166+
VectorData queryVector,
167+
int k,
168+
int numCands,
169+
RescoreVectorBuilder rescoreVectorBuilder,
170+
Float similarity
171+
) {
172+
this(field, queryVector, null, k, numCands, rescoreVectorBuilder, similarity);
149173
}
150174

151175
/**
@@ -156,13 +180,21 @@ public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCand
156180
* @param k the final number of nearest neighbors to return as top hits
157181
* @param numCands the number of nearest neighbor candidates to consider per shard
158182
*/
159-
public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands, Float similarity) {
183+
public KnnSearchBuilder(
184+
String field,
185+
QueryVectorBuilder queryVectorBuilder,
186+
int k,
187+
int numCands,
188+
RescoreVectorBuilder rescoreVectorBuilder,
189+
Float similarity
190+
) {
160191
this(
161192
field,
162193
null,
163194
Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())),
164195
k,
165196
numCands,
197+
rescoreVectorBuilder,
166198
similarity
167199
);
168200
}
@@ -173,16 +205,30 @@ public KnnSearchBuilder(
173205
QueryVectorBuilder queryVectorBuilder,
174206
int k,
175207
int numCands,
208+
RescoreVectorBuilder rescoreVectorBuilder,
176209
Float similarity
177210
) {
178-
this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, null, DEFAULT_BOOST);
211+
this(
212+
field,
213+
queryVectorBuilder,
214+
queryVector,
215+
new ArrayList<>(),
216+
k,
217+
numCands,
218+
rescoreVectorBuilder,
219+
similarity,
220+
null,
221+
null,
222+
DEFAULT_BOOST
223+
);
179224
}
180225

181226
private KnnSearchBuilder(
182227
String field,
183228
Supplier<float[]> querySupplier,
184229
Integer k,
185230
Integer numCands,
231+
RescoreVectorBuilder rescoreVectorBuilder,
186232
List<QueryBuilder> filterQueries,
187233
Float similarity
188234
) {
@@ -194,6 +240,7 @@ private KnnSearchBuilder(
194240
this.filterQueries = filterQueries;
195241
this.querySupplier = querySupplier;
196242
this.similarity = similarity;
243+
this.rescoreVectorBuilder = rescoreVectorBuilder;
197244
}
198245

199246
private KnnSearchBuilder(
@@ -203,6 +250,7 @@ private KnnSearchBuilder(
203250
List<QueryBuilder> filterQueries,
204251
int k,
205252
int numCandidates,
253+
RescoreVectorBuilder rescoreVectorBuilder,
206254
Float similarity,
207255
InnerHitBuilder innerHitBuilder,
208256
String queryName,
@@ -242,6 +290,7 @@ private KnnSearchBuilder(
242290
this.queryVectorBuilder = queryVectorBuilder;
243291
this.k = k;
244292
this.numCands = numCandidates;
293+
this.rescoreVectorBuilder = rescoreVectorBuilder;
245294
this.innerHitBuilder = innerHitBuilder;
246295
this.similarity = similarity;
247296
this.queryName = queryName;
@@ -280,6 +329,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException {
280329
if (in.getTransportVersion().onOrAfter(V_8_11_X)) {
281330
this.innerHitBuilder = in.readOptionalWriteable(InnerHitBuilder::new);
282331
}
332+
if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) {
333+
this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new);
334+
} else {
335+
this.rescoreVectorBuilder = null;
336+
}
283337
}
284338

285339
public int k() {
@@ -290,6 +344,10 @@ public int getNumCands() {
290344
return numCands;
291345
}
292346

347+
public RescoreVectorBuilder getRescoreVectorBuilder() {
348+
return rescoreVectorBuilder;
349+
}
350+
293351
public QueryVectorBuilder getQueryVectorBuilder() {
294352
return queryVectorBuilder;
295353
}
@@ -358,7 +416,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
358416
if (querySupplier.get() == null) {
359417
return this;
360418
}
361-
return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost)
419+
return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, rescoreVectorBuilder, similarity).boost(boost)
362420
.queryName(queryName)
363421
.addFilterQueries(filterQueries)
364422
.innerHit(innerHitBuilder);
@@ -381,7 +439,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
381439
}
382440
ll.onResponse(null);
383441
})));
384-
return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost)
442+
return new KnnSearchBuilder(field, toSet::get, k, numCands, rescoreVectorBuilder, filterQueries, similarity).boost(boost)
385443
.queryName(queryName)
386444
.innerHit(innerHitBuilder);
387445
}
@@ -395,7 +453,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
395453
rewrittenQueries.add(rewrittenQuery);
396454
}
397455
if (changed) {
398-
return new KnnSearchBuilder(field, queryVector, k, numCands, similarity).boost(boost)
456+
return new KnnSearchBuilder(field, queryVector, k, numCands, rescoreVectorBuilder, similarity).boost(boost)
399457
.queryName(queryName)
400458
.addFilterQueries(rewrittenQueries)
401459
.innerHit(innerHitBuilder);
@@ -407,7 +465,7 @@ public KnnVectorQueryBuilder toQueryBuilder() {
407465
if (queryVectorBuilder != null) {
408466
throw new IllegalArgumentException("missing rewrite");
409467
}
410-
return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, similarity).boost(boost)
468+
return new KnnVectorQueryBuilder(field, queryVector, null, numCands, rescoreVectorBuilder, similarity).boost(boost)
411469
.queryName(queryName)
412470
.addFilterQueries(filterQueries);
413471
}
@@ -423,6 +481,7 @@ public boolean equals(Object o) {
423481
KnnSearchBuilder that = (KnnSearchBuilder) o;
424482
return k == that.k
425483
&& numCands == that.numCands
484+
&& Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder)
426485
&& Objects.equals(field, that.field)
427486
&& Objects.equals(queryVector, that.queryVector)
428487
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
@@ -442,6 +501,7 @@ public int hashCode() {
442501
numCands,
443502
querySupplier,
444503
queryVectorBuilder,
504+
rescoreVectorBuilder,
445505
similarity,
446506
Objects.hashCode(queryVector),
447507
Objects.hashCode(filterQueries),
@@ -486,6 +546,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
486546
if (queryName != null) {
487547
builder.field(NAME_FIELD.getPreferredName(), queryName);
488548
}
549+
if (rescoreVectorBuilder != null) {
550+
builder.startObject(RESCORE_FIELD.getPreferredName());
551+
rescoreVectorBuilder.toXContent(builder, params);
552+
builder.endObject();
553+
}
489554

490555
return builder;
491556
}
@@ -526,6 +591,9 @@ public void writeTo(StreamOutput out) throws IOException {
526591
if (out.getTransportVersion().onOrAfter(V_8_11_X)) {
527592
out.writeOptionalWriteable(innerHitBuilder);
528593
}
594+
if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) {
595+
out.writeOptionalWriteable(rescoreVectorBuilder);
596+
}
529597
}
530598

531599
public static class Builder {
@@ -540,6 +608,7 @@ public static class Builder {
540608
private String queryName;
541609
private float boost = DEFAULT_BOOST;
542610
private InnerHitBuilder innerHitBuilder;
611+
private RescoreVectorBuilder rescoreVectorBuilder;
543612

544613
public Builder addFilterQueries(List<QueryBuilder> filterQueries) {
545614
Objects.requireNonNull(filterQueries);
@@ -592,6 +661,11 @@ public Builder similarity(Float similarity) {
592661
return this;
593662
}
594663

664+
public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) {
665+
this.rescoreVectorBuilder = rescoreVectorBuilder;
666+
return this;
667+
}
668+
595669
public KnnSearchBuilder build(int size) {
596670
int requestSize = size < 0 ? DEFAULT_SIZE : size;
597671
int adjustedK = k == null ? requestSize : k;
@@ -605,6 +679,7 @@ public KnnSearchBuilder build(int size) {
605679
filterQueries,
606680
adjustedK,
607681
adjustedNumCandidates,
682+
rescoreVectorBuilder,
608683
similarity,
609684
innerHitBuilder,
610685
queryName,

server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ public void testRewriteShardSearchRequestWithRank() {
344344
SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25)
345345
.knnSearch(
346346
List.of(
347-
new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null),
348-
new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null)
347+
new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null, null),
348+
new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null, null)
349349
)
350350
)
351351
.rankBuilder(new TestRankBuilder(100));

0 commit comments

Comments
 (0)