99
1010package org .elasticsearch .search .query ;
1111
12+ import org .apache .lucene .util .VectorUtil ;
1213import org .elasticsearch .cluster .metadata .IndexMetadata ;
1314import org .elasticsearch .common .settings .Settings ;
1415import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
@@ -30,10 +31,11 @@ public class VectorIT extends ESIntegTestCase {
3031 private static final String VECTOR_FIELD = "vector" ;
3132 private static final String NUM_ID_FIELD = "num_id" ;
3233
33- private static void randomVector (float [] vector ) {
34+ private static void randomVector (float [] vector , int constant ) {
3435 for (int i = 0 ; i < vector .length ; i ++) {
35- vector [i ] = randomFloat ();
36+ vector [i ] = randomFloat () * constant ;
3637 }
38+ VectorUtil .l2normalize (vector );
3739 }
3840
3941 @ Before
@@ -43,6 +45,7 @@ public void setup() throws IOException {
4345 .startObject ("properties" )
4446 .startObject (VECTOR_FIELD )
4547 .field ("type" , "dense_vector" )
48+ .field ("similarity" , "dot_product" )
4649 .startObject ("index_options" )
4750 .field ("type" , "hnsw" )
4851 .endObject ()
@@ -59,20 +62,21 @@ public void setup() throws IOException {
5962 .build ();
6063 prepareCreate (INDEX_NAME ).setMapping (mapping ).setSettings (settings ).get ();
6164 ensureGreen (INDEX_NAME );
65+ float [] vector = new float [16 ];
6266 for (int i = 0 ; i < 150 ; i ++) {
63- float [] vector = new float [8 ];
64- randomVector (vector );
67+ randomVector (vector , i % 25 + 1 );
6568 prepareIndex (INDEX_NAME ).setId (Integer .toString (i )).setSource (VECTOR_FIELD , vector , NUM_ID_FIELD , i ).get ();
6669 }
6770 forceMerge (true );
6871 refresh (INDEX_NAME );
6972 }
7073
7174 public void testFilteredQueryStrategy () {
72- float [] vector = new float [8 ];
73- randomVector (vector );
75+ float [] vector = new float [16 ];
76+ randomVector (vector , 25 );
77+ int upperLimit = 35 ;
7478 var query = new KnnSearchBuilder (VECTOR_FIELD , vector , 1 , 1 , null , null ).addFilterQuery (
75- QueryBuilders .rangeQuery (NUM_ID_FIELD ).lte (30 )
79+ QueryBuilders .rangeQuery (NUM_ID_FIELD ).lte (35 )
7680 );
7781 assertResponse (client ().prepareSearch (INDEX_NAME ).setKnnSearch (List .of (query )).setSize (1 ).setProfile (true ), acornResponse -> {
7882 assertNotEquals (0 , acornResponse .getHits ().getHits ().length );
@@ -116,6 +120,8 @@ public void testFilteredQueryStrategy() {
116120 assertTrue (
117121 "fanoutVectorOps [" + fanoutVectorOpsSum + "] is not gt acornVectorOps [" + vectorOpsSum + "]" ,
118122 fanoutVectorOpsSum > vectorOpsSum
123+ // if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
124+ || (fanoutVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1 )
119125 );
120126 });
121127 });
0 commit comments