3737import org .elasticsearch .test .TransportVersionUtils ;
3838import org .elasticsearch .xcontent .XContentBuilder ;
3939import org .elasticsearch .xcontent .XContentFactory ;
40+ import org .junit .Before ;
4041
4142import java .io .IOException ;
4243import java .util .ArrayList ;
5657abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase <KnnVectorQueryBuilder > {
5758 private static final String VECTOR_FIELD = "vector" ;
5859 private static final String VECTOR_ALIAS_FIELD = "vector_alias" ;
59- protected final String indexType = indexType ();
60- protected final int VECTOR_DIMENSION = indexType .contains ("bbq" ) ? 64 : 3 ;
6160 protected static final Set <String > QUANTIZED_INDEX_TYPES = Set .of (
6261 "int8_hnsw" ,
6362 "int4_hnsw" ,
@@ -69,6 +68,15 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
6968 protected static final Set <String > NON_QUANTIZED_INDEX_TYPES = Set .of ("hnsw" , "flat" );
7069 protected static final Set <String > ALL_INDEX_TYPES = Stream .concat (QUANTIZED_INDEX_TYPES .stream (), NON_QUANTIZED_INDEX_TYPES .stream ())
7170 .collect (Collectors .toUnmodifiableSet ());
71+ protected static String indexType ;
72+ protected static int vectorDimensions ;
73+
74+ @ Before
75+ private void checkIndexTypeAndDimensions () {
76+ // Check that these are initialized - should be done as part of the createAdditionalMappings method
77+ assertNotNull (indexType );
78+ assertNotEquals (0 , vectorDimensions );
79+ }
7280
7381 abstract DenseVectorFieldMapper .ElementType elementType ();
7482
@@ -81,20 +89,32 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(
8189 );
8290
8391 protected boolean isQuantizedElementType () {
84- return QUANTIZED_INDEX_TYPES .contains (indexType () );
92+ return QUANTIZED_INDEX_TYPES .contains (indexType );
8593 }
8694
87- protected abstract String indexType ();
95+ protected abstract String randomIndexType ();
8896
8997 @ Override
9098 protected void initializeAdditionalMappings (MapperService mapperService ) throws IOException {
9199
100+ // These fields are initialized here, as mappings are initialized only once per test class.
101+ // We want the subclasses to be able to override the index type and vector dimensions so we don't make this static / BeforeClass
102+ // for initialization.
103+ indexType = randomIndexType ();
104+ if (indexType .contains ("bbq" )) {
105+ vectorDimensions = 64 ;
106+ } else if (indexType .contains ("int4" )) {
107+ vectorDimensions = 4 ;
108+ } else {
109+ vectorDimensions = 3 ;
110+ }
111+
92112 XContentBuilder builder = XContentFactory .jsonBuilder ()
93113 .startObject ()
94114 .startObject ("properties" )
95115 .startObject (VECTOR_FIELD )
96116 .field ("type" , "dense_vector" )
97- .field ("dims" , VECTOR_DIMENSION )
117+ .field ("dims" , vectorDimensions )
98118 .field ("index" , true )
99119 .field ("similarity" , "l2_norm" )
100120 .field ("element_type" , elementType ())
@@ -201,7 +221,7 @@ public void testWrongDimension() {
201221 IllegalArgumentException e = expectThrows (IllegalArgumentException .class , () -> query .doToQuery (context ));
202222 assertThat (
203223 e .getMessage (),
204- containsString ("The query vector has a different number of dimensions [2] than the document vectors [3 ]" )
224+ containsString ("The query vector has a different number of dimensions [2] than the document vectors [" + vectorDimensions + " ]" )
205225 );
206226 }
207227
@@ -286,7 +306,7 @@ public void testMustRewrite() throws IOException {
286306 KnnVectorQueryBuilder query = new KnnVectorQueryBuilder (
287307 VECTOR_FIELD ,
288308 new float [] { 1.0f , 2.0f , 3.0f },
289- VECTOR_DIMENSION ,
309+ vectorDimensions ,
290310 null ,
291311 null ,
292312 null
0 commit comments