Skip to content

Commit 3ef07fa

Browse files
committed
Add index types to vector query builder tests
1 parent ddc6094 commit 3ef07fa

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
import java.io.IOException;
4242
import java.util.ArrayList;
4343
import java.util.List;
44+
import java.util.Set;
45+
import java.util.stream.Collectors;
46+
import java.util.stream.Stream;
4447

4548
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
4649
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
@@ -53,7 +56,19 @@
5356
abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
5457
private static final String VECTOR_FIELD = "vector";
5558
private static final String VECTOR_ALIAS_FIELD = "vector_alias";
56-
static final int VECTOR_DIMENSION = 3;
59+
protected final String indexType = indexType();
60+
protected final int VECTOR_DIMENSION = indexType.contains("bbq") ? 64 : 3;
61+
protected static final Set<String> QUANTIZED_INDEX_TYPES = Set.of(
62+
"int8_hnsw",
63+
"int4_hnsw",
64+
"bbq_hnsw",
65+
"int8_flat",
66+
"int4_flat",
67+
"bbq_flat"
68+
);
69+
protected static final Set<String> NON_QUANTIZED_INDEX_TYPES = Set.of("hnsw", "flat");
70+
protected static final Set<String> ALL_INDEX_TYPES = Stream.concat(QUANTIZED_INDEX_TYPES.stream(), NON_QUANTIZED_INDEX_TYPES.stream())
71+
.collect(Collectors.toUnmodifiableSet());
5772

5873
abstract DenseVectorFieldMapper.ElementType elementType();
5974

@@ -65,8 +80,15 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(
6580
Float similarity
6681
);
6782

83+
protected boolean isQuantizedElementType() {
84+
return QUANTIZED_INDEX_TYPES.contains(indexType());
85+
}
86+
87+
protected abstract String indexType();
88+
6889
@Override
6990
protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
91+
7092
XContentBuilder builder = XContentFactory.jsonBuilder()
7193
.startObject()
7294
.startObject("properties")
@@ -76,6 +98,9 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws
7698
.field("index", true)
7799
.field("similarity", "l2_norm")
78100
.field("element_type", elementType())
101+
.startObject("index_options")
102+
.field("type", indexType)
103+
.endObject()
79104
.endObject()
80105
.startObject(VECTOR_ALIAS_FIELD)
81106
.field("type", "alias")
@@ -126,7 +151,7 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() {
126151

127152
@Override
128153
protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
129-
if (queryBuilder.rescoreVectorBuilder() != null) {
154+
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
130155
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
131156
query = rescoreQuery.innerQuery();
132157
}
@@ -154,7 +179,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
154179
// The field should always be resolved to the concrete field
155180
Integer k = queryBuilder.k();
156181
Integer numCands = queryBuilder.numCands();
157-
if (queryBuilder.rescoreVectorBuilder() != null) {
182+
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
158183
Float rescoreOversample = queryBuilder.rescoreVectorBuilder().oversample();
159184
k = k == null ? null : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)));
160185
numCands = numCands == null ? null : Math.max(k == null ? 0 : k, numCands);
@@ -330,19 +355,24 @@ public void testBWCVersionSerializationQuery() throws IOException {
330355

331356
public void testBWCVersionSerializationRescoreVector() throws IOException {
332357
KnnVectorQueryBuilder query = createTestQueryBuilder();
358+
TransportVersion version = TransportVersionUtils.randomVersionBetween(
359+
random(),
360+
TransportVersions.V_8_8_1,
361+
TransportVersionUtils.getPreviousVersion(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)
362+
);
363+
VectorData vectorData = version.onOrAfter(TransportVersions.V_8_14_0)
364+
? query.queryVector()
365+
: VectorData.fromFloats(query.queryVector().asFloatVector());
366+
Integer k = version.before(TransportVersions.V_8_15_0) ? null : query.k();
333367
KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder(
334368
query.getFieldName(),
335-
query.queryVector(),
336-
query.k(),
369+
vectorData,
370+
k,
337371
query.numCands(),
338372
null,
339373
query.getVectorSimilarity()
340374
).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
341-
assertBWCSerialization(
342-
query,
343-
queryNoRescoreVector,
344-
TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_8_8_0, TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)
345-
);
375+
assertBWCSerialization(query, queryNoRescoreVector, version);
346376
}
347377

348378
private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException {

server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,9 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(
3131
}
3232
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity);
3333
}
34+
35+
@Override
36+
protected String indexType() {
37+
return randomFrom(NON_QUANTIZED_INDEX_TYPES);
38+
}
3439
}

server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,9 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder(
3131
}
3232
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity);
3333
}
34+
35+
@Override
36+
protected String indexType() {
37+
return randomFrom(ALL_INDEX_TYPES);
38+
}
3439
}

0 commit comments

Comments
 (0)