Skip to content

Commit bd920c5

Browse files
committed
Add tests
1 parent be76444 commit bd920c5

File tree

6 files changed

+101
-35
lines changed

6 files changed

+101
-35
lines changed

server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,9 @@ public void testKnnSearchAction() throws IOException {
417417
// how the action works (it builds a kNN query under the hood)
418418
float[] queryVector = randomVector();
419419
assertResponse(
420-
client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null)).setSize(2),
420+
client().prepareSearch("index1", "index2")
421+
.setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null))
422+
.setSize(2),
421423
response -> {
422424
// The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard
423425
assertHitCount(response, 5 * 2);

server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ public void testKnnRewriteForInnerHits() throws IOException {
270270
new float[] { 1.0f, 2.0f, 3.0f },
271271
null,
272272
1,
273+
null,
273274
null
274275
);
275276
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.util.ArrayList;
4343
import java.util.List;
4444

45+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
4546
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
4647
import static org.hamcrest.Matchers.containsString;
4748
import static org.hamcrest.Matchers.equalTo;
@@ -56,7 +57,13 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
5657

5758
abstract DenseVectorFieldMapper.ElementType elementType();
5859

59-
abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity);
60+
abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(
61+
String fieldName,
62+
Integer k,
63+
int numCands,
64+
RescoreVectorBuilder rescoreVectorBuilder,
65+
Float similarity
66+
);
6067

6168
@Override
6269
protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
@@ -88,7 +95,13 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() {
8895
String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD;
8996
Integer k = randomBoolean() ? null : randomIntBetween(1, 100);
9097
int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000);
91-
KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, k, numCands, randomFloat());
98+
KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(
99+
fieldName,
100+
k,
101+
numCands,
102+
randomRescoreVectorBuilder(),
103+
randomFloat()
104+
);
92105

93106
if (randomBoolean()) {
94107
List<QueryBuilder> filters = new ArrayList<>();
@@ -99,11 +112,24 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() {
99112
}
100113
queryBuilder.addFilterQueries(filters);
101114
}
115+
102116
return queryBuilder;
103117
}
104118

119+
protected RescoreVectorBuilder randomRescoreVectorBuilder() {
120+
if (randomBoolean()) {
121+
return null;
122+
}
123+
124+
return new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false));
125+
}
126+
105127
@Override
106128
protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
129+
if (queryBuilder.rescoreVectorBuilder() != null) {
130+
assertTrue(query instanceof org.apache.lucene.queries.function.FunctionScoreQuery);
131+
query = ((org.apache.lucene.queries.function.FunctionScoreQuery) query).getWrappedQuery();
132+
}
107133
if (queryBuilder.getVectorSimilarity() != null) {
108134
assertTrue(query instanceof VectorSimilarityQuery);
109135
Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery();
@@ -126,21 +152,17 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
126152
BooleanQuery booleanQuery = builder.build();
127153
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
128154
// The field should always be resolved to the concrete field
155+
Integer k = queryBuilder.k();
156+
Integer numCands = queryBuilder.numCands();
157+
if (queryBuilder.rescoreVectorBuilder() != null) {
158+
Float rescoreOversample = queryBuilder.rescoreVectorBuilder().oversample();
159+
k = k == null ? null : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)));
160+
numCands = numCands == null ? null : Math.max(k == null ? 0 : k, numCands);
161+
}
162+
129163
Query knnVectorQueryBuilt = switch (elementType()) {
130-
case BYTE, BIT -> new ESKnnByteVectorQuery(
131-
VECTOR_FIELD,
132-
queryBuilder.queryVector().asByteVector(),
133-
queryBuilder.k(),
134-
queryBuilder.numCands(),
135-
filterQuery
136-
);
137-
case FLOAT -> new ESKnnFloatVectorQuery(
138-
VECTOR_FIELD,
139-
queryBuilder.queryVector().asFloatVector(),
140-
queryBuilder.k(),
141-
queryBuilder.numCands(),
142-
filterQuery
143-
);
164+
case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery);
165+
case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery);
144166
};
145167
if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
146168
query = vectorSimilarityQuery.getInnerKnnQuery();
@@ -150,7 +172,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
150172

151173
public void testWrongDimension() {
152174
SearchExecutionContext context = createSearchExecutionContext();
153-
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null);
175+
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null, null);
154176
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
155177
assertThat(
156178
e.getMessage(),
@@ -160,15 +182,15 @@ public void testWrongDimension() {
160182

161183
public void testNonexistentField() {
162184
SearchExecutionContext context = createSearchExecutionContext();
163-
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null);
185+
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null);
164186
context.setAllowUnmappedFields(false);
165187
QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context));
166188
assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]"));
167189
}
168190

169191
public void testNonexistentFieldReturnEmpty() throws IOException {
170192
SearchExecutionContext context = createSearchExecutionContext();
171-
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null);
193+
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null);
172194
Query queryNone = query.doToQuery(context);
173195
assertThat(queryNone, instanceOf(MatchNoDocsQuery.class));
174196
}
@@ -180,6 +202,7 @@ public void testWrongFieldType() {
180202
new float[] { 1.0f, 1.0f, 1.0f },
181203
5,
182204
10,
205+
null,
183206
null
184207
);
185208
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
@@ -191,14 +214,14 @@ public void testNumCandsLessThanK() {
191214
int numCands = 3;
192215
IllegalArgumentException e = expectThrows(
193216
IllegalArgumentException.class,
194-
() -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null)
217+
() -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null, null)
195218
);
196219
assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
197220
}
198221

199222
@Override
200223
public void testValidOutput() {
201-
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null);
224+
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null, null);
202225
String expected = """
203226
{
204227
"knn" : {
@@ -213,7 +236,7 @@ public void testValidOutput() {
213236
}""";
214237
assertEquals(expected, query.toString());
215238

216-
KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null);
239+
KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null);
217240
String expected2 = """
218241
{
219242
"knn" : {
@@ -240,6 +263,7 @@ public void testMustRewrite() throws IOException {
240263
new float[] { 1.0f, 2.0f, 3.0f },
241264
VECTOR_DIMENSION,
242265
null,
266+
null,
243267
null
244268
);
245269
query.addFilterQuery(termQuery);
@@ -254,9 +278,14 @@ public void testMustRewrite() throws IOException {
254278
public void testBWCVersionSerializationFilters() throws IOException {
255279
KnnVectorQueryBuilder query = createTestQueryBuilder();
256280
VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
257-
KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null)
258-
.queryName(query.queryName())
259-
.boost(query.boost());
281+
KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(
282+
query.getFieldName(),
283+
vectorData,
284+
null,
285+
query.numCands(),
286+
null,
287+
null
288+
).queryName(query.queryName()).boost(query.boost());
260289
TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
261290
random(),
262291
TransportVersions.V_8_0_0,
@@ -268,10 +297,14 @@ public void testBWCVersionSerializationFilters() throws IOException {
268297
public void testBWCVersionSerializationSimilarity() throws IOException {
269298
KnnVectorQueryBuilder query = createTestQueryBuilder();
270299
VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
271-
KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null)
272-
.queryName(query.queryName())
273-
.boost(query.boost())
274-
.addFilterQueries(query.filterQueries());
300+
KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(
301+
query.getFieldName(),
302+
vectorData,
303+
null,
304+
query.numCands(),
305+
null,
306+
null
307+
).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
275308
assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0);
276309
}
277310

@@ -289,11 +322,29 @@ public void testBWCVersionSerializationQuery() throws IOException {
289322
vectorData,
290323
null,
291324
query.numCands(),
325+
null,
292326
similarity
293327
).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
294328
assertBWCSerialization(query, queryOlderVersion, differentQueryVersion);
295329
}
296330

331+
public void testBWCVersionSerializationRescoreVector() throws IOException {
332+
KnnVectorQueryBuilder query = createTestQueryBuilder();
333+
KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder(
334+
query.getFieldName(),
335+
query.queryVector(),
336+
query.k(),
337+
query.numCands(),
338+
null,
339+
query.getVectorSimilarity()
340+
).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
341+
assertBWCSerialization(
342+
query,
343+
queryNoRescoreVector,
344+
TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_8_8_0, TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)
345+
);
346+
}
347+
297348
private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException {
298349
assertSerialization(bwcQuery, version);
299350
try (BytesStreamOutput output = new BytesStreamOutput()) {

server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,17 @@ DenseVectorFieldMapper.ElementType elementType() {
1818
}
1919

2020
@Override
21-
protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) {
21+
protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(
22+
String fieldName,
23+
Integer k,
24+
int numCands,
25+
RescoreVectorBuilder rescoreVectorBuilder,
26+
Float similarity
27+
) {
2228
byte[] vector = new byte[VECTOR_DIMENSION];
2329
for (int i = 0; i < vector.length; i++) {
2430
vector[i] = randomByte();
2531
}
26-
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity);
32+
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity);
2733
}
2834
}

server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,17 @@ DenseVectorFieldMapper.ElementType elementType() {
1818
}
1919

2020
@Override
21-
KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) {
21+
KnnVectorQueryBuilder createKnnVectorQueryBuilder(
22+
String fieldName,
23+
Integer k,
24+
int numCands,
25+
RescoreVectorBuilder rescoreVectorBuilder,
26+
Float similarity
27+
) {
2228
float[] vector = new float[VECTOR_DIMENSION];
2329
for (int i = 0; i < vector.length; i++) {
2430
vector[i] = randomFloat();
2531
}
26-
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity);
32+
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity);
2733
}
2834
}

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ public void testToQueryBuilder() {
167167
builder.addFilterQuery(filter);
168168
}
169169

170-
QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, similarity).addFilterQueries(filterQueries)
170+
QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, null, similarity).addFilterQueries(filterQueries)
171171
.boost(boost);
172172
assertEquals(expected, builder.toQueryBuilder());
173173
}

0 commit comments

Comments
 (0)