99
1010package  org .elasticsearch .search .query ;
1111
12+ import  org .apache .lucene .util .VectorUtil ;
1213import  org .elasticsearch .cluster .metadata .IndexMetadata ;
1314import  org .elasticsearch .common .settings .Settings ;
1415import  org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
@@ -30,10 +31,11 @@ public class VectorIT extends ESIntegTestCase {
3031    private  static  final  String  VECTOR_FIELD  = "vector" ;
3132    private  static  final  String  NUM_ID_FIELD  = "num_id" ;
3233
33-     private  static  void  randomVector (float [] vector ) {
34+     private  static  void  randomVector (float [] vector ,  int   constant ) {
3435        for  (int  i  = 0 ; i  < vector .length ; i ++) {
35-             vector [i ] = randomFloat ();
36+             vector [i ] = randomFloat () *  constant ;
3637        }
38+         VectorUtil .l2normalize (vector );
3739    }
3840
3941    @ Before 
@@ -43,6 +45,7 @@ public void setup() throws IOException {
4345            .startObject ("properties" )
4446            .startObject (VECTOR_FIELD )
4547            .field ("type" , "dense_vector" )
48+             .field ("similarity" , "dot_product" )
4649            .startObject ("index_options" )
4750            .field ("type" , "hnsw" )
4851            .endObject ()
@@ -59,20 +62,21 @@ public void setup() throws IOException {
5962            .build ();
6063        prepareCreate (INDEX_NAME ).setMapping (mapping ).setSettings (settings ).get ();
6164        ensureGreen (INDEX_NAME );
65+         float [] vector  = new  float [16 ];
6266        for  (int  i  = 0 ; i  < 150 ; i ++) {
63-             float [] vector  = new  float [8 ];
64-             randomVector (vector );
67+             randomVector (vector , i  % 25  + 1 );
6568            prepareIndex (INDEX_NAME ).setId (Integer .toString (i )).setSource (VECTOR_FIELD , vector , NUM_ID_FIELD , i ).get ();
6669        }
6770        forceMerge (true );
6871        refresh (INDEX_NAME );
6972    }
7073
7174    public  void  testFilteredQueryStrategy () {
72-         float [] vector  = new  float [8 ];
73-         randomVector (vector );
75+         float [] vector  = new  float [16 ];
76+         randomVector (vector , 25 );
77+         int  upperLimit  = 35 ;
7478        var  query  = new  KnnSearchBuilder (VECTOR_FIELD , vector , 1 , 1 , null , null ).addFilterQuery (
75-             QueryBuilders .rangeQuery (NUM_ID_FIELD ).lte (30 )
79+             QueryBuilders .rangeQuery (NUM_ID_FIELD ).lte (35 )
7680        );
7781        assertResponse (client ().prepareSearch (INDEX_NAME ).setKnnSearch (List .of (query )).setSize (1 ).setProfile (true ), acornResponse  -> {
7882            assertNotEquals (0 , acornResponse .getHits ().getHits ().length );
@@ -116,6 +120,8 @@ public void testFilteredQueryStrategy() {
116120                assertTrue (
117121                    "fanoutVectorOps ["  + fanoutVectorOpsSum  + "] is not gt acornVectorOps ["  + vectorOpsSum  + "]" ,
118122                    fanoutVectorOpsSum  > vectorOpsSum 
123+                         // if both switch to brute-force due to excessive exploration, they will both equal to upperLimit 
124+                         || (fanoutVectorOpsSum  == vectorOpsSum  && vectorOpsSum  == upperLimit  + 1 )
119125                );
120126            });
121127        });
0 commit comments