1010import org .elasticsearch .action .index .IndexRequestBuilder ;
1111import org .elasticsearch .cluster .metadata .IndexMetadata ;
1212import org .elasticsearch .common .settings .Settings ;
13+ import org .elasticsearch .search .vectors .KnnVectorQueryBuilder ;
1314import org .elasticsearch .xcontent .XContentBuilder ;
1415import org .elasticsearch .xcontent .XContentFactory ;
1516import org .elasticsearch .xpack .esql .EsqlTestUtils ;
1819
1920import java .io .IOException ;
2021import java .util .ArrayList ;
22+ import java .util .Arrays ;
2123import java .util .HashMap ;
2224import java .util .List ;
25+ import java .util .Locale ;
2326import java .util .Map ;
2427
2528import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .assertAcked ;
2629
2730public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
2831
2932 private final Map <Integer , List <Float >> indexedVectors = new HashMap <>();
33+ private int numDocs ;
34+ private int numDims ;
3035
3136 public void testKnnDefaults () {
32- var query = """
37+ float [] queryVector = new float [numDims ];
38+ Arrays .fill (queryVector , 1.0f );
39+
40+ var query = String .format (Locale .ROOT , """
3341 FROM test METADATA _score
34- | WHERE knn(vector, [1.0, 1.0, 1.0] )
42+ | WHERE knn(vector, %s )
3543 | KEEP id, floats, _score, vector
3644 | SORT _score DESC
37- """ ;
45+ """ , Arrays . toString ( queryVector )) ;
3846
3947 try (var resp = run (query )) {
4048 assertColumnNames (resp .columns (), List .of ("id" , "floats" , "_score" , "vector" ));
4149 assertColumnTypes (resp .columns (), List .of ("integer" , "double" , "double" , "dense_vector" ));
4250
4351 List <List <Object >> valuesList = EsqlTestUtils .getValuesList (resp );
44- assertEquals (indexedVectors .size (), valuesList .size ());
52+ assertEquals (Math . min ( indexedVectors .size (), 10 ), valuesList .size ());
4553 for (int i = 0 ; i < valuesList .size (); i ++) {
4654 List <Object > row = valuesList .get (i );
4755 // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases
@@ -62,12 +70,15 @@ public void testKnnDefaults() {
6270 }
6371
6472 public void testKnnOptions () {
65- var query = """
73+ float [] queryVector = new float [numDims ];
74+ Arrays .fill (queryVector , 1.0f );
75+
76+ var query = String .format (Locale .ROOT , """
6677 FROM test METADATA _score
67- | WHERE knn(vector, [1.0, 1.0, 1.0] , {"k": 5})
78+ | WHERE knn(vector, %s , {"k": 5})
6879 | KEEP id, floats, _score, vector
6980 | SORT _score DESC
70- """ ;
81+ """ , Arrays . toString ( queryVector )) ;
7182
7283 try (var resp = run (query )) {
7384 assertColumnNames (resp .columns (), List .of ("id" , "floats" , "_score" , "vector" ));
@@ -79,20 +90,24 @@ public void testKnnOptions() {
7990 }
8091
8192 public void testKnnNonPushedDown () {
82- var query = """
93+ float [] queryVector = new float [numDims ];
94+ Arrays .fill (queryVector , 1.0f );
95+
96+ // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
97+ var query = String .format (Locale .ROOT , """
8398 FROM test METADATA _score
84- | WHERE knn(vector, [1.0, 1.0, 1.0], {"k": 5}) OR id % 2 == 0
99+ | WHERE knn(vector, %s, {"k": 5}) OR id > 10
85100 | KEEP id, floats, _score, vector
86101 | SORT _score DESC
87- """ ;
102+ """ , Arrays . toString ( queryVector )) ;
88103
89104 try (var resp = run (query )) {
90105 assertColumnNames (resp .columns (), List .of ("id" , "floats" , "_score" , "vector" ));
91106 assertColumnTypes (resp .columns (), List .of ("integer" , "double" , "double" , "dense_vector" ));
92107
93108 List <List <Object >> valuesList = EsqlTestUtils .getValuesList (resp );
94- // K = 5, 2 more for % operator, total 7
95- assertEquals (7 , valuesList .size ());
109+ // K = 5, 1 more for every id > 10
110+ assertEquals (5 + Math . max ( 0 , numDocs - 10 - 1 ) , valuesList .size ());
96111 }
97112 }
98113
@@ -120,11 +135,11 @@ public void setup() throws IOException {
120135 .put (IndexMetadata .SETTING_NUMBER_OF_REPLICAS , 0 )
121136 .put (IndexMetadata .SETTING_NUMBER_OF_SHARDS , 1 );
122137
123- var CreateRequest = client .prepareCreate (indexName ).setMapping (mapping ).setSettings (settingsBuilder .build ());
124- assertAcked (CreateRequest );
138+ var createRequest = client .prepareCreate (indexName ).setMapping (mapping ).setSettings (settingsBuilder .build ());
139+ assertAcked (createRequest );
125140
126- int numDocs = 10 ;
127- int numDims = 3 ;
141+ numDocs = randomIntBetween ( 10 , 20 ) ;
142+ numDims = randomIntBetween ( 3 , 10 ) ;
128143 IndexRequestBuilder [] docs = new IndexRequestBuilder [numDocs ];
129144 float value = 0.0f ;
130145 for (int i = 0 ; i < numDocs ; i ++) {
0 commit comments