Skip to content

Commit be76444

Browse files
committed
Change API to use "rescore": {"oversample": 1.0}
1 parent df06716 commit be76444

File tree

7 files changed

+171
-46
lines changed

7 files changed

+171
-46
lines changed

server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public void testTelemetryForRetrievers() throws IOException {
9898
{
9999
performSearch(
100100
new SearchSourceBuilder().retriever(
101-
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null))
101+
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))
102102
)
103103
);
104104
}

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,8 +2097,10 @@ private Query createKnnByteQuery(
20972097
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
20982098
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
20992099
}
2100-
int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2101-
int adjustedNumCands = Math.max(adjustedK, numCands);
2100+
Integer adjustedK = k == null || rescoreOversample == null
2101+
? null
2102+
: Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2103+
int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands);
21022104

21032105
Query knnQuery = parentFilter != null
21042106
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
@@ -2149,8 +2151,10 @@ && isNotUnitVector(squaredMagnitude)) {
21492151
}
21502152
}
21512153

2152-
int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2153-
int adjustedNumCands = Math.max(adjustedK, numCands);
2154+
Integer adjustedK = k == null || rescoreOversample == null
2155+
? k
2156+
: Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)));
2157+
int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands);
21542158
Query knnQuery = parentFilter != null
21552159
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
21562160
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ public KnnVectorQueryBuilder toQueryBuilder() {
407407
if (queryVectorBuilder != null) {
408408
throw new IllegalArgumentException("missing rewrite");
409409
}
410-
return new KnnVectorQueryBuilder(field, queryVector, null, numCands, similarity).boost(boost)
410+
return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, similarity).boost(boost)
411411
.queryName(queryName)
412412
.addFilterQueries(filterQueries);
413413
}

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ public KnnVectorQueryBuilder toQueryBuilder() {
256256
if (numCands > NUM_CANDS_LIMIT) {
257257
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
258258
}
259-
return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null);
259+
return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, null);
260260
}
261261

262262
@Override

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
6767
public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
6868
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
6969
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
70-
public static final ParseField RESCORE_VECTOR_OVERSAMPLE = new ParseField("rescore_vector_oversample");
7170
public static final ParseField FILTER_FIELD = new ParseField("filter");
7271
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
72+
public static final ParseField RESCORE_FIELD = new ParseField("rescore");
7373

7474
public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
7575
"knn",
@@ -80,8 +80,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
8080
null,
8181
(Integer) args[2],
8282
(Integer) args[3],
83-
(Float) args[4],
84-
(Float) args[6]
83+
(RescoreVectorBuilder) args[6],
84+
(Float) args[4]
8585
)
8686
);
8787

@@ -101,7 +101,12 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
101101
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
102102
QUERY_VECTOR_BUILDER_FIELD
103103
);
104-
PARSER.declareFloat(optionalConstructorArg(), RESCORE_VECTOR_OVERSAMPLE);
104+
PARSER.declareField(
105+
optionalConstructorArg(),
106+
(p, c) -> RescoreVectorBuilder.fromXContent(p),
107+
RESCORE_FIELD,
108+
ObjectParser.ValueType.OBJECT_OR_NULL
109+
);
105110
PARSER.declareFieldArray(
106111
KnnVectorQueryBuilder::addFilterQueries,
107112
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
@@ -123,10 +128,17 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
123128
private final Float vectorSimilarity;
124129
private final QueryVectorBuilder queryVectorBuilder;
125130
private final Supplier<float[]> queryVectorSupplier;
126-
private final Float rescoreOversample;
131+
private final RescoreVectorBuilder rescoreVectorBuilder;
127132

128-
public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
129-
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity, null);
133+
public KnnVectorQueryBuilder(
134+
String fieldName,
135+
float[] queryVector,
136+
Integer k,
137+
Integer numCands,
138+
RescoreVectorBuilder rescoreVectorBuilder,
139+
Float vectorSimilarity
140+
) {
141+
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity);
130142
}
131143

132144
public KnnVectorQueryBuilder(
@@ -136,27 +148,29 @@ public KnnVectorQueryBuilder(
136148
Integer numCands,
137149
Float vectorSimilarity
138150
) {
139-
this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity, null);
151+
this(fieldName, null, queryVectorBuilder, null, k, numCands, null, vectorSimilarity);
140152
}
141153

142-
public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
143-
this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, vectorSimilarity);
144-
}
145-
146-
public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
147-
this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity);
154+
public KnnVectorQueryBuilder(
155+
String fieldName,
156+
byte[] queryVector,
157+
Integer k,
158+
Integer numCands,
159+
RescoreVectorBuilder rescoreVectorBuilder,
160+
Float vectorSimilarity
161+
) {
162+
this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity);
148163
}
149164

150-
private KnnVectorQueryBuilder(
165+
public KnnVectorQueryBuilder(
151166
String fieldName,
152167
VectorData queryVector,
153-
QueryVectorBuilder queryVectorBuilder,
154-
Supplier<float[]> queryVectorSupplier,
155168
Integer k,
156169
Integer numCands,
170+
RescoreVectorBuilder rescoreVectorBuilder,
157171
Float vectorSimilarity
158172
) {
159-
this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity, null);
173+
this(fieldName, queryVector, null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity);
160174
}
161175

162176
private KnnVectorQueryBuilder(
@@ -166,8 +180,8 @@ private KnnVectorQueryBuilder(
166180
Supplier<float[]> queryVectorSupplier,
167181
Integer k,
168182
Integer numCands,
169-
Float vectorSimilarity,
170-
Float rescoreOversample
183+
RescoreVectorBuilder rescoreVectorBuilder,
184+
Float vectorSimilarity
171185
) {
172186
if (k != null && k < 1) {
173187
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
@@ -204,7 +218,7 @@ private KnnVectorQueryBuilder(
204218
this.vectorSimilarity = vectorSimilarity;
205219
this.queryVectorBuilder = queryVectorBuilder;
206220
this.queryVectorSupplier = queryVectorSupplier;
207-
this.rescoreOversample = rescoreOversample;
221+
this.rescoreVectorBuilder = rescoreVectorBuilder;
208222
}
209223

210224
public KnnVectorQueryBuilder(StreamInput in) throws IOException {
@@ -246,9 +260,9 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
246260
this.queryVectorBuilder = null;
247261
}
248262
if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) {
249-
this.rescoreOversample = in.readOptionalFloat();
263+
this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new);
250264
} else {
251-
this.rescoreOversample = null;
265+
this.rescoreVectorBuilder = null;
252266
}
253267

254268
this.queryVectorSupplier = null;
@@ -276,10 +290,6 @@ public Integer numCands() {
276290
return numCands;
277291
}
278292

279-
public Float rescoreOversample() {
280-
return rescoreOversample;
281-
}
282-
283293
public List<QueryBuilder> filterQueries() {
284294
return filterQueries;
285295
}
@@ -289,6 +299,10 @@ public QueryVectorBuilder queryVectorBuilder() {
289299
return queryVectorBuilder;
290300
}
291301

302+
public RescoreVectorBuilder rescoreVectorBuilder() {
303+
return rescoreVectorBuilder;
304+
}
305+
292306
public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) {
293307
Objects.requireNonNull(filterQuery);
294308
this.filterQueries.add(filterQuery);
@@ -356,7 +370,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
356370
out.writeOptionalNamedWriteable(queryVectorBuilder);
357371
}
358372
if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) {
359-
out.writeOptionalFloat(rescoreOversample);
373+
out.writeOptionalWriteable(rescoreVectorBuilder);
360374
}
361375
}
362376

@@ -391,6 +405,11 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
391405
}
392406
builder.endArray();
393407
}
408+
if (rescoreVectorBuilder != null) {
409+
builder.startObject(RESCORE_FIELD.getPreferredName());
410+
rescoreVectorBuilder.toXContent(builder, params);
411+
builder.endObject();
412+
}
394413
boostAndQueryNameToXContent(builder);
395414
builder.endObject();
396415
}
@@ -406,7 +425,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
406425
if (queryVectorSupplier.get() == null) {
407426
return this;
408427
}
409-
return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, vectorSimilarity).boost(boost)
428+
return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, rescoreVectorBuilder, vectorSimilarity)
429+
.boost(boost)
410430
.queryName(queryName)
411431
.addFilterQueries(filterQueries);
412432
}
@@ -428,9 +448,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
428448
}
429449
ll.onResponse(null);
430450
})));
431-
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, k, numCands, vectorSimilarity).boost(
432-
boost
433-
).queryName(queryName).addFilterQueries(filterQueries);
451+
return new KnnVectorQueryBuilder(
452+
fieldName,
453+
queryVector,
454+
queryVectorBuilder,
455+
toSet::get,
456+
k,
457+
numCands,
458+
rescoreVectorBuilder,
459+
vectorSimilarity
460+
).boost(boost).queryName(queryName).addFilterQueries(filterQueries);
434461
}
435462
if (ctx.convertToInnerHitsRewriteContext() != null) {
436463
return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName);
@@ -448,10 +475,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
448475
rewrittenQueries.add(rewrittenQuery);
449476
}
450477
if (changed) {
451-
return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, k, numCands, vectorSimilarity)
452-
.boost(boost)
453-
.queryName(queryName)
454-
.addFilterQueries(rewrittenQueries);
478+
return new KnnVectorQueryBuilder(
479+
fieldName,
480+
queryVector,
481+
queryVectorBuilder,
482+
queryVectorSupplier,
483+
k,
484+
numCands,
485+
rescoreVectorBuilder,
486+
vectorSimilarity
487+
).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries);
455488
}
456489
return this;
457490
}
@@ -495,6 +528,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
495528

496529
DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
497530
String parentPath = context.nestedLookup().getNestedParent(fieldName);
531+
Float rescoreOversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();
498532

499533
if (parentPath != null) {
500534
final BitSetProducer parentBitSet;
@@ -550,7 +584,7 @@ protected int doHashCode() {
550584
filterQueries,
551585
vectorSimilarity,
552586
queryVectorBuilder,
553-
rescoreOversample
587+
rescoreVectorBuilder
554588
);
555589
}
556590

@@ -563,7 +597,7 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
563597
&& Objects.equals(filterQueries, other.filterQueries)
564598
&& Objects.equals(vectorSimilarity, other.vectorSimilarity)
565599
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder)
566-
&& Objects.equals(rescoreOversample, other.rescoreOversample);
600+
&& Objects.equals(rescoreVectorBuilder, other.rescoreVectorBuilder);
567601
}
568602

569603
@Override

server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.common.io.stream.Writeable;
15+
import org.elasticsearch.xcontent.ConstructingObjectParser;
16+
import org.elasticsearch.xcontent.ParseField;
17+
import org.elasticsearch.xcontent.ToXContentObject;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xcontent.XContentParser;
20+
21+
import java.io.IOException;
22+
import java.util.Objects;
23+
24+
public class RescoreVectorBuilder implements Writeable, ToXContentObject {
25+
26+
public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample");
27+
public static final int MIN_OVERSAMPLE = 1;
28+
private static final ConstructingObjectParser<RescoreVectorBuilder, Void> PARSER = new ConstructingObjectParser<>(
29+
"rescore",
30+
args -> new RescoreVectorBuilder((Float) args[0])
31+
);
32+
33+
static {
34+
PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), OVERSAMPLE_FIELD);
35+
}
36+
37+
// Oversample is currently required (the constructor rejects null) because it is the only field
38+
// of the rescore object. The parser declares it as optional so more fields can be added later.
39+
private final Float oversample;
40+
41+
public RescoreVectorBuilder(Float oversample) {
42+
Objects.requireNonNull(oversample, "[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be set");
43+
if (oversample <= MIN_OVERSAMPLE) {
44+
throw new IllegalArgumentException("[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be > " + MIN_OVERSAMPLE);
45+
}
46+
this.oversample = oversample;
47+
}
48+
49+
public RescoreVectorBuilder(StreamInput in) throws IOException {
50+
this.oversample = in.readOptionalFloat();
51+
}
52+
53+
@Override
54+
public void writeTo(StreamOutput out) throws IOException {
55+
out.writeOptionalFloat(oversample);
56+
}
57+
58+
@Override
59+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
60+
if (oversample != null) {
61+
builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample);
62+
}
63+
64+
return builder;
65+
}
66+
67+
public static RescoreVectorBuilder fromXContent(XContentParser parser) {
68+
return PARSER.apply(parser, null);
69+
}
70+
71+
@Override
72+
public boolean equals(Object o) {
73+
if (this == o) return true;
74+
if (o == null || getClass() != o.getClass()) return false;
75+
RescoreVectorBuilder that = (RescoreVectorBuilder) o;
76+
return Objects.equals(oversample, that.oversample);
77+
}
78+
79+
@Override
80+
public int hashCode() {
81+
return Objects.hashCode(oversample);
82+
}
83+
84+
public Float oversample() {
85+
return oversample;
86+
}
87+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost
556556
);
557557
}
558558

559-
yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null);
559+
yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null,null);
560560
}
561561
default -> throw new IllegalStateException(
562562
"Field ["

0 commit comments

Comments
 (0)