Skip to content

Commit 155eee6

Browse files
authored
[8.15] Ensure vector similarity correctly limits inner_hits returned for nested kNN (#111363) (#111426)
* Ensure vector similarity correctly limits inner_hits returned for nested kNN (#111363) For nested kNN we support not only similarity thresholds, but also multi-passage search while retrieving more than one nearest passage. However, the inner_hits retrieved for the kNN search would ignore the restricted similarity. Meaning, the inner hits would return all passages, not just the ones within the limited similarity and this is confusing. closes: #111093 (cherry picked from commit 69c9697) * fixing for backport * adj for backport * fix compilation for tests
1 parent 3f30e38 commit 155eee6

File tree

14 files changed

+180
-44
lines changed

14 files changed

+180
-44
lines changed

docs/changelog/111363.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 111363
2+
summary: Ensure vector similarity correctly limits `inner_hits` returned for nested
3+
kNN
4+
area: Vector Search
5+
type: bug
6+
issues: [111093]

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,53 @@ setup:
411411

412412
- match: {hits.total.value: 1}
413413
- match: {hits.hits.0._id: "2"}
414+
---
415+
"nested Knn search with required similarity appropriately filters inner_hits":
416+
- requires:
417+
cluster_features: "gte_v8.15.0"
418+
reason: 'bugfix for 8.15'
419+
420+
- do:
421+
search:
422+
index: test
423+
body:
424+
query:
425+
nested:
426+
path: nested
427+
inner_hits:
428+
size: 3
429+
_source: false
430+
fields:
431+
- nested.paragraph_id
432+
query:
433+
knn:
434+
field: nested.vector
435+
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
436+
num_candidates: 3
437+
similarity: 10.5
438+
439+
- match: {hits.total.value: 1}
440+
- match: {hits.hits.0._id: "2"}
441+
- length: {hits.hits.0.inner_hits.nested.hits.hits: 1}
442+
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
443+
444+
- do:
445+
search:
446+
index: test
447+
body:
448+
knn:
449+
field: nested.vector
450+
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
451+
num_candidates: 3
452+
k: 3
453+
similarity: 10.5
454+
inner_hits:
455+
size: 3
456+
_source: false
457+
fields:
458+
- nested.paragraph_id
459+
460+
- match: {hits.total.value: 1}
461+
- match: {hits.hits.0._id: "2"}
462+
- length: {hits.hits.0.inner_hits.nested.hits.hits: 1}
463+
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ static TransportVersion def(int id) {
210210
public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0);
211211
public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0);
212212
public static final TransportVersion ENTERPRISE_GEOIP_DOWNLOADER_BACKPORT_8_15 = def(8_702_00_1);
213+
public static final TransportVersion FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15 = def(8_702_00_2);
213214

214215
/*
215216
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
155155
QueryBuilder query = new KnnScoreDocQueryBuilder(
156156
scoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS),
157157
source.knnSearch().get(i).getField(),
158-
source.knnSearch().get(i).getQueryVector()
158+
source.knnSearch().get(i).getQueryVector(),
159+
source.knnSearch().get(i).getSimilarity()
159160
).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName());
160161
if (nestedPath != null) {
161162
query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit());

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,17 +1711,21 @@ public Query termQuery(Object value, SearchExecutionContext context) {
17111711
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
17121712
}
17131713

1714-
public Query createExactKnnQuery(VectorData queryVector) {
1714+
public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) {
17151715
if (isIndexed() == false) {
17161716
throw new IllegalArgumentException(
17171717
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
17181718
);
17191719
}
1720-
return switch (elementType) {
1720+
Query knnQuery = switch (elementType) {
17211721
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
17221722
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
17231723
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
17241724
};
1725+
if (vectorSimilarity != null) {
1726+
knnQuery = new VectorSimilarityQuery(knnQuery, vectorSimilarity, similarity.score(vectorSimilarity, elementType, dims));
1727+
}
1728+
return knnQuery;
17251729
}
17261730

17271731
private Query createExactKnnBitQuery(byte[] queryVector) {

server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,18 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuil
3232
public static final String NAME = "exact_knn";
3333
private final String field;
3434
private final VectorData query;
35+
private final Float vectorSimilarity;
3536

3637
/**
3738
* Creates a query builder.
3839
*
3940
* @param query the query vector
4041
* @param field the field that was used for the kNN query
4142
*/
42-
public ExactKnnQueryBuilder(float[] query, String field) {
43-
this(VectorData.fromFloats(query), field);
44-
}
45-
46-
/**
47-
* Creates a query builder.
48-
*
49-
* @param query the query vector
50-
* @param field the field that was used for the kNN query
51-
*/
52-
public ExactKnnQueryBuilder(VectorData query, String field) {
43+
public ExactKnnQueryBuilder(VectorData query, String field, Float vectorSimilarity) {
5344
this.query = query;
5445
this.field = field;
46+
this.vectorSimilarity = vectorSimilarity;
5547
}
5648

5749
public ExactKnnQueryBuilder(StreamInput in) throws IOException {
@@ -62,6 +54,11 @@ public ExactKnnQueryBuilder(StreamInput in) throws IOException {
6254
this.query = VectorData.fromFloats(in.readFloatArray());
6355
}
6456
this.field = in.readString();
57+
if (in.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
58+
this.vectorSimilarity = in.readOptionalFloat();
59+
} else {
60+
this.vectorSimilarity = null;
61+
}
6562
}
6663

6764
String getField() {
@@ -72,6 +69,10 @@ VectorData getQuery() {
7269
return query;
7370
}
7471

72+
Float vectorSimilarity() {
73+
return vectorSimilarity;
74+
}
75+
7576
@Override
7677
public String getWriteableName() {
7778
return NAME;
@@ -85,13 +86,19 @@ protected void doWriteTo(StreamOutput out) throws IOException {
8586
out.writeFloatArray(query.asFloatVector());
8687
}
8788
out.writeString(field);
89+
if (out.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
90+
out.writeOptionalFloat(vectorSimilarity);
91+
}
8892
}
8993

9094
@Override
9195
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
9296
builder.startObject(NAME);
9397
builder.field("query", query);
9498
builder.field("field", field);
99+
if (vectorSimilarity != null) {
100+
builder.field("similarity", vectorSimilarity);
101+
}
95102
boostAndQueryNameToXContent(builder);
96103
builder.endObject();
97104
}
@@ -108,17 +115,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
108115
);
109116
}
110117
final DenseVectorFieldMapper.DenseVectorFieldType vectorFieldType = (DenseVectorFieldMapper.DenseVectorFieldType) fieldType;
111-
return vectorFieldType.createExactKnnQuery(query);
118+
return vectorFieldType.createExactKnnQuery(query, vectorSimilarity);
112119
}
113120

114121
@Override
115122
protected boolean doEquals(ExactKnnQueryBuilder other) {
116-
return field.equals(other.field) && Objects.equals(query, other.query);
123+
return field.equals(other.field) && Objects.equals(query, other.query) && Objects.equals(vectorSimilarity, other.vectorSimilarity);
117124
}
118125

119126
@Override
120127
protected int doHashCode() {
121-
return Objects.hash(field, Objects.hashCode(query));
128+
return Objects.hash(field, Objects.hashCode(query), vectorSimilarity);
122129
}
123130

124131
@Override

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,19 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
3737
private final ScoreDoc[] scoreDocs;
3838
private final String fieldName;
3939
private final VectorData queryVector;
40+
private final Float vectorSimilarity;
4041

4142
/**
4243
* Creates a query builder.
4344
*
4445
* @param scoreDocs the docs and scores this query should match. The array must be
4546
* sorted in order of ascending doc IDs.
4647
*/
47-
public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, float[] queryVector) {
48-
this(scoreDocs, fieldName, VectorData.fromFloats(queryVector));
49-
}
50-
51-
/**
52-
* Creates a query builder.
53-
*
54-
* @param scoreDocs the docs and scores this query should match. The array must be
55-
* sorted in order of ascending doc IDs.
56-
*/
57-
public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector) {
48+
public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector, Float vectorSimilarity) {
5849
this.scoreDocs = scoreDocs;
5950
this.fieldName = fieldName;
6051
this.queryVector = queryVector;
52+
this.vectorSimilarity = vectorSimilarity;
6153
}
6254

6355
public KnnScoreDocQueryBuilder(StreamInput in) throws IOException {
@@ -78,6 +70,11 @@ public KnnScoreDocQueryBuilder(StreamInput in) throws IOException {
7870
this.fieldName = null;
7971
this.queryVector = null;
8072
}
73+
if (in.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
74+
this.vectorSimilarity = in.readOptionalFloat();
75+
} else {
76+
this.vectorSimilarity = null;
77+
}
8178
}
8279

8380
@Override
@@ -97,6 +94,10 @@ VectorData queryVector() {
9794
return queryVector;
9895
}
9996

97+
Float vectorSimilarity() {
98+
return vectorSimilarity;
99+
}
100+
100101
@Override
101102
protected void doWriteTo(StreamOutput out) throws IOException {
102103
out.writeArray(Lucene::writeScoreDoc, scoreDocs);
@@ -113,6 +114,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
113114
out.writeBoolean(false);
114115
}
115116
}
117+
if (out.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) {
118+
out.writeOptionalFloat(vectorSimilarity);
119+
}
116120
}
117121

118122
@Override
@@ -129,6 +133,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
129133
if (queryVector != null) {
130134
builder.field("query", queryVector);
131135
}
136+
if (vectorSimilarity != null) {
137+
builder.field("similarity", vectorSimilarity);
138+
}
132139
boostAndQueryNameToXContent(builder);
133140
builder.endObject();
134141
}
@@ -154,7 +161,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
154161
return new MatchNoneQueryBuilder("The \"" + getName() + "\" query was rewritten to a \"match_none\" query.");
155162
}
156163
if (queryRewriteContext.convertToInnerHitsRewriteContext() != null && queryVector != null && fieldName != null) {
157-
return new ExactKnnQueryBuilder(queryVector, fieldName);
164+
return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity);
158165
}
159166
return super.doRewrite(queryRewriteContext);
160167
}
@@ -193,7 +200,9 @@ protected boolean doEquals(KnnScoreDocQueryBuilder other) {
193200
return false;
194201
}
195202
}
196-
return Objects.equals(fieldName, other.fieldName) && Objects.equals(queryVector, other.queryVector);
203+
return Objects.equals(fieldName, other.fieldName)
204+
&& Objects.equals(queryVector, other.queryVector)
205+
&& Objects.equals(vectorSimilarity, other.vectorSimilarity);
197206
}
198207

199208
@Override
@@ -203,7 +212,7 @@ protected int doHashCode() {
203212
int hashCode = Objects.hash(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
204213
result = 31 * result + hashCode;
205214
}
206-
return Objects.hash(result, fieldName, Objects.hashCode(queryVector));
215+
return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector));
207216
}
208217

209218
@Override

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ public KnnVectorQueryBuilder toQueryBuilder() {
403403
.addFilterQueries(filterQueries);
404404
}
405405

406+
public Float getSimilarity() {
407+
return similarity;
408+
}
409+
406410
@Override
407411
public boolean equals(Object o) {
408412
if (this == o) return true;

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
400400
).queryName(queryName).addFilterQueries(filterQueries);
401401
}
402402
if (ctx.convertToInnerHitsRewriteContext() != null) {
403-
return new ExactKnnQueryBuilder(queryVector, fieldName).boost(boost).queryName(queryName);
403+
return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName);
404404
}
405405
boolean changed = false;
406406
List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());

server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.search.rank.TestRankBuilder;
3636
import org.elasticsearch.search.vectors.KnnScoreDocQueryBuilder;
3737
import org.elasticsearch.search.vectors.KnnSearchBuilder;
38+
import org.elasticsearch.search.vectors.VectorData;
3839
import org.elasticsearch.test.ESTestCase;
3940
import org.elasticsearch.test.InternalAggregationTestCase;
4041
import org.elasticsearch.transport.Transport;
@@ -351,12 +352,14 @@ public void testRewriteShardSearchRequestWithRank() {
351352
KnnScoreDocQueryBuilder ksdqb0 = new KnnScoreDocQueryBuilder(
352353
new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) },
353354
"vector",
354-
new float[] { 0.0f }
355+
VectorData.fromFloats(new float[] { 0.0f }),
356+
null
355357
);
356358
KnnScoreDocQueryBuilder ksdqb1 = new KnnScoreDocQueryBuilder(
357359
new ScoreDoc[] { new ScoreDoc(1, 2.0f, 1) },
358360
"vector2",
359-
new float[] { 0.0f }
361+
VectorData.fromFloats(new float[] { 0.0f }),
362+
null
360363
);
361364
assertEquals(
362365
List.of(bm25, ksdqb0, ksdqb1),

0 commit comments

Comments
 (0)