1212import org .elasticsearch .common .settings .Settings ;
1313import org .elasticsearch .xcontent .XContentBuilder ;
1414import org .elasticsearch .xcontent .XContentFactory ;
15+ import org .elasticsearch .xpack .esql .EsqlTestUtils ;
1516import org .elasticsearch .xpack .esql .action .AbstractEsqlIntegTestCase ;
1617import org .junit .Before ;
1718
@@ -29,14 +30,34 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
2930
3031 public void testKnn () {
3132 var query = """
32- FROM test
33- | WHERE knn(vector, [1.0, 2.0, 3.0])
34- | KEEP id, floats
33+ FROM test METADATA _score
34+ | WHERE knn(vector, [1.0, 1.0, 1.0])
35+ | KEEP id, floats, _score, vector
36+ | SORT _score DESC
3537 """ ;
3638
3739 try (var resp = run (query )) {
38- assertColumnNames (resp .columns (), List .of ("id" , "floats" ));
39- assertColumnTypes (resp .columns (), List .of ("integer" , "double" ));
40+ assertColumnNames (resp .columns (), List .of ("id" , "floats" , "_score" , "vector" ));
41+ assertColumnTypes (resp .columns (), List .of ("integer" , "double" , "double" , "dense_vector" ));
42+
43+ List <List <Object >> valuesList = EsqlTestUtils .getValuesList (resp );
44+ assertEquals (indexedVectors .size (), valuesList .size ());
45+ for (int i = 0 ; i < valuesList .size (); i ++) {
46+ List <Object > row = valuesList .get (i );
47+ // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases
48+ assertEquals (i , row .getFirst ());
49+ @ SuppressWarnings ("unchecked" )
50+ // Vectors should be the same
51+ List <Double > floats = (List <Double >)row .get (1 );
52+ for (int j = 0 ; j < floats .size (); j ++) {
53+ assertEquals (floats .get (j ).floatValue (), indexedVectors .get (i ).get (j ), 0f );
54+ }
55+ var score = (Double ) row .get (2 );
56+ assertNotNull (score );
57+ assertTrue (score > 0.0 );
58+ // dense_vector is null for now
59+ assertNull (row .get (3 ));
60+ }
4061 }
4162 }
4263
@@ -67,7 +88,7 @@ public void setup() throws IOException {
6788 var CreateRequest = client .prepareCreate (indexName ).setMapping (mapping ).setSettings (settingsBuilder .build ());
6889 assertAcked (CreateRequest );
6990
70- int numDocs = randomIntBetween ( 10 , 100 ) ;
91+ int numDocs = 10 ;
7192 int numDims = 3 ;
7293 IndexRequestBuilder [] docs = new IndexRequestBuilder [numDocs ];
7394 float value = 0.0f ;
@@ -76,7 +97,7 @@ public void setup() throws IOException {
7697 for (int j = 0 ; j < numDims ; j ++) {
7798 vector .add (value ++);
7899 }
79- docs [i ] = prepareIndex ("test" ).setId ("" + i ).setSource ("id" , String .valueOf (i ), "vector" , vector );
100+ docs [i ] = prepareIndex ("test" ).setId ("" + i ).setSource ("id" , String .valueOf (i ), "floats" , vector , " vector" , vector );
80101 indexedVectors .put (i , vector );
81102 }
82103
0 commit comments