Skip to content

Commit 863b6e6

Browse files
committed
Fix tests for vector query builders to ensure multiple dimensions / index types can be used
1 parent fd9188f commit 863b6e6

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.test.TransportVersionUtils;
3838
import org.elasticsearch.xcontent.XContentBuilder;
3939
import org.elasticsearch.xcontent.XContentFactory;
40+
import org.junit.Before;
4041

4142
import java.io.IOException;
4243
import java.util.ArrayList;
@@ -56,8 +57,6 @@
5657
abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
5758
private static final String VECTOR_FIELD = "vector";
5859
private static final String VECTOR_ALIAS_FIELD = "vector_alias";
59-
protected final String indexType = indexType();
60-
protected final int VECTOR_DIMENSION = indexType.contains("bbq") ? 64 : 3;
6160
protected static final Set<String> QUANTIZED_INDEX_TYPES = Set.of(
6261
"int8_hnsw",
6362
"int4_hnsw",
@@ -69,6 +68,15 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
6968
protected static final Set<String> NON_QUANTIZED_INDEX_TYPES = Set.of("hnsw", "flat");
7069
protected static final Set<String> ALL_INDEX_TYPES = Stream.concat(QUANTIZED_INDEX_TYPES.stream(), NON_QUANTIZED_INDEX_TYPES.stream())
7170
.collect(Collectors.toUnmodifiableSet());
71+
protected static String indexType;
72+
protected static int vectorDimensions;
73+
74+
@Before
75+
private void checkIndexTypeAndDimensions() {
76+
// Check that these are initialized - should be done as part of the createAdditionalMappings method
77+
assertNotNull(indexType);
78+
assertNotEquals(0, vectorDimensions);
79+
}
7280

7381
abstract DenseVectorFieldMapper.ElementType elementType();
7482

@@ -81,20 +89,32 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(
8189
);
8290

8391
protected boolean isQuantizedElementType() {
84-
return QUANTIZED_INDEX_TYPES.contains(indexType());
92+
return QUANTIZED_INDEX_TYPES.contains(indexType);
8593
}
8694

87-
protected abstract String indexType();
95+
protected abstract String randomIndexType();
8896

8997
@Override
9098
protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
9199

100+
// These fields are initialized here, as mappings are initialized only once per test class.
101+
// We want the subclasses to be able to override the index type and vector dimensions so we don't make this static / BeforeClass
102+
// for initialization.
103+
indexType = randomIndexType();
104+
if (indexType.contains("bbq")) {
105+
vectorDimensions = 64;
106+
} else if (indexType.contains("int4")) {
107+
vectorDimensions = 4;
108+
} else {
109+
vectorDimensions = 3;
110+
}
111+
92112
XContentBuilder builder = XContentFactory.jsonBuilder()
93113
.startObject()
94114
.startObject("properties")
95115
.startObject(VECTOR_FIELD)
96116
.field("type", "dense_vector")
97-
.field("dims", VECTOR_DIMENSION)
117+
.field("dims", vectorDimensions)
98118
.field("index", true)
99119
.field("similarity", "l2_norm")
100120
.field("element_type", elementType())
@@ -201,7 +221,7 @@ public void testWrongDimension() {
201221
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
202222
assertThat(
203223
e.getMessage(),
204-
containsString("The query vector has a different number of dimensions [2] than the document vectors [3]")
224+
containsString("The query vector has a different number of dimensions [2] than the document vectors [" + vectorDimensions + "]")
205225
);
206226
}
207227

@@ -286,7 +306,7 @@ public void testMustRewrite() throws IOException {
286306
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(
287307
VECTOR_FIELD,
288308
new float[] { 1.0f, 2.0f, 3.0f },
289-
VECTOR_DIMENSION,
309+
vectorDimensions,
290310
null,
291311
null,
292312
null

server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(
2525
RescoreVectorBuilder rescoreVectorBuilder,
2626
Float similarity
2727
) {
28-
byte[] vector = new byte[VECTOR_DIMENSION];
28+
byte[] vector = new byte[vectorDimensions];
2929
for (int i = 0; i < vector.length; i++) {
3030
vector[i] = randomByte();
3131
}
3232
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity);
3333
}
3434

3535
@Override
36-
protected String indexType() {
36+
protected String randomIndexType() {
3737
return randomFrom(NON_QUANTIZED_INDEX_TYPES);
3838
}
3939
}

server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder(
2525
RescoreVectorBuilder rescoreVectorBuilder,
2626
Float similarity
2727
) {
28-
float[] vector = new float[VECTOR_DIMENSION];
28+
float[] vector = new float[vectorDimensions];
2929
for (int i = 0; i < vector.length; i++) {
3030
vector[i] = randomFloat();
3131
}
3232
return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity);
3333
}
3434

3535
@Override
36-
protected String indexType() {
36+
protected String randomIndexType() {
3737
return randomFrom(ALL_INDEX_TYPES);
3838
}
3939
}

0 commit comments

Comments
 (0)