4141import java .io .IOException ;
4242import java .util .ArrayList ;
4343import java .util .List ;
44+ import java .util .Set ;
45+ import java .util .stream .Collectors ;
46+ import java .util .stream .Stream ;
4447
4548import static org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .OVERSAMPLE_LIMIT ;
4649import static org .elasticsearch .search .SearchService .DEFAULT_SIZE ;
5356abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase <KnnVectorQueryBuilder > {
5457 private static final String VECTOR_FIELD = "vector" ;
5558 private static final String VECTOR_ALIAS_FIELD = "vector_alias" ;
56- static final int VECTOR_DIMENSION = 3 ;
59+ protected final String indexType = indexType ();
60+ protected final int VECTOR_DIMENSION = indexType .contains ("bbq" ) ? 64 : 3 ;
61+ protected static final Set <String > QUANTIZED_INDEX_TYPES = Set .of (
62+ "int8_hnsw" ,
63+ "int4_hnsw" ,
64+ "bbq_hnsw" ,
65+ "int8_flat" ,
66+ "int4_flat" ,
67+ "bbq_flat"
68+ );
69+ protected static final Set <String > NON_QUANTIZED_INDEX_TYPES = Set .of ("hnsw" , "flat" );
70+ protected static final Set <String > ALL_INDEX_TYPES = Stream .concat (QUANTIZED_INDEX_TYPES .stream (), NON_QUANTIZED_INDEX_TYPES .stream ())
71+ .collect (Collectors .toUnmodifiableSet ());
5772
5873 abstract DenseVectorFieldMapper .ElementType elementType ();
5974
@@ -65,8 +80,15 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(
6580 Float similarity
6681 );
6782
83+ protected boolean isQuantizedElementType () {
84+ return QUANTIZED_INDEX_TYPES .contains (indexType ());
85+ }
86+
87+ protected abstract String indexType ();
88+
6889 @ Override
6990 protected void initializeAdditionalMappings (MapperService mapperService ) throws IOException {
91+
7092 XContentBuilder builder = XContentFactory .jsonBuilder ()
7193 .startObject ()
7294 .startObject ("properties" )
@@ -76,6 +98,9 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws
7698 .field ("index" , true )
7799 .field ("similarity" , "l2_norm" )
78100 .field ("element_type" , elementType ())
101+ .startObject ("index_options" )
102+ .field ("type" , indexType )
103+ .endObject ()
79104 .endObject ()
80105 .startObject (VECTOR_ALIAS_FIELD )
81106 .field ("type" , "alias" )
@@ -126,7 +151,7 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() {
126151
127152 @ Override
128153 protected void doAssertLuceneQuery (KnnVectorQueryBuilder queryBuilder , Query query , SearchExecutionContext context ) throws IOException {
129- if (queryBuilder .rescoreVectorBuilder () != null ) {
154+ if (queryBuilder .rescoreVectorBuilder () != null && isQuantizedElementType () ) {
130155 RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery ) query ;
131156 query = rescoreQuery .innerQuery ();
132157 }
@@ -154,7 +179,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
154179 // The field should always be resolved to the concrete field
155180 Integer k = queryBuilder .k ();
156181 Integer numCands = queryBuilder .numCands ();
157- if (queryBuilder .rescoreVectorBuilder () != null ) {
182+ if (queryBuilder .rescoreVectorBuilder () != null && isQuantizedElementType () ) {
158183 Float rescoreOversample = queryBuilder .rescoreVectorBuilder ().oversample ();
159184 k = k == null ? null : Integer .valueOf (Math .min (OVERSAMPLE_LIMIT , (int ) Math .ceil (k * rescoreOversample )));
160185 numCands = numCands == null ? null : Math .max (k == null ? 0 : k , numCands );
@@ -330,19 +355,24 @@ public void testBWCVersionSerializationQuery() throws IOException {
330355
331356 public void testBWCVersionSerializationRescoreVector () throws IOException {
332357 KnnVectorQueryBuilder query = createTestQueryBuilder ();
358+ TransportVersion version = TransportVersionUtils .randomVersionBetween (
359+ random (),
360+ TransportVersions .V_8_8_1 ,
361+ TransportVersionUtils .getPreviousVersion (TransportVersions .KNN_QUERY_RESCORE_OVERSAMPLE )
362+ );
363+ VectorData vectorData = version .onOrAfter (TransportVersions .V_8_14_0 )
364+ ? query .queryVector ()
365+ : VectorData .fromFloats (query .queryVector ().asFloatVector ());
366+ Integer k = version .before (TransportVersions .V_8_15_0 ) ? null : query .k ();
333367 KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder (
334368 query .getFieldName (),
335- query . queryVector () ,
336- query . k () ,
369+ vectorData ,
370+ k ,
337371 query .numCands (),
338372 null ,
339373 query .getVectorSimilarity ()
340374 ).queryName (query .queryName ()).boost (query .boost ()).addFilterQueries (query .filterQueries ());
341- assertBWCSerialization (
342- query ,
343- queryNoRescoreVector ,
344- TransportVersionUtils .randomVersionBetween (random (), TransportVersions .V_8_8_0 , TransportVersions .KNN_QUERY_RESCORE_OVERSAMPLE )
345- );
375+ assertBWCSerialization (query , queryNoRescoreVector , version );
346376 }
347377
348378 private void assertBWCSerialization (QueryBuilder newQuery , QueryBuilder bwcQuery , TransportVersion version ) throws IOException {
0 commit comments