1212import org .elasticsearch .action .bulk .BulkResponse ;
1313import org .elasticsearch .common .settings .Settings ;
1414import org .elasticsearch .plugins .Plugin ;
15+ import org .elasticsearch .search .SearchHit ;
1516import org .elasticsearch .search .vectors .KnnSearchBuilder ;
1617import org .elasticsearch .test .ESIntegTestCase ;
1718import org .elasticsearch .xpack .gpu .GPUPlugin ;
1819import org .elasticsearch .xpack .gpu .GPUSupport ;
20+ import org .junit .Assert ;
1921
2022import java .util .Collection ;
2123import java .util .List ;
@@ -34,40 +36,102 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
3436
3537 public void testBasic () {
3638 assumeTrue ("cuvs not supported" , GPUSupport .isSupported (false ));
39+ String indexName = "index1" ;
3740 final int dims = randomIntBetween (4 , 128 );
3841 final int [] numDocs = new int [] { randomIntBetween (1 , 100 ), 1 , 2 , randomIntBetween (1 , 100 ) };
39- createIndex (dims );
42+ createIndex (indexName , dims , false );
4043 int totalDocs = 0 ;
4144 for (int i = 0 ; i < numDocs .length ; i ++) {
42- indexDocs (numDocs [i ], dims , i * 100 );
45+ indexDocs (indexName , numDocs [i ], dims , i * 100 );
4346 totalDocs += numDocs [i ];
4447 }
4548 refresh ();
46- assertSearch (randomFloatVector (dims ), totalDocs );
49+ assertSearch (indexName , randomFloatVector (dims ), totalDocs );
50+ }
51+
52+ public void testSortedIndexReturnsSameResultsAsUnsorted () {
53+ assumeTrue ("cuvs not supported" , GPUSupport .isSupported (false ));
54+ String indexName1 = "index_unsorted" ;
55+ String indexName2 = "index_sorted" ;
56+ final int dims = randomIntBetween (4 , 128 );
57+ createIndex (indexName1 , dims , false );
58+ createIndex (indexName2 , dims , true );
59+
60+ final int [] numDocs = new int [] { randomIntBetween (50 , 100 ), randomIntBetween (50 , 100 ) };
61+ for (int i = 0 ; i < numDocs .length ; i ++) {
62+ BulkRequestBuilder bulkRequest1 = client ().prepareBulk ();
63+ BulkRequestBuilder bulkRequest2 = client ().prepareBulk ();
64+ for (int j = 0 ; j < numDocs [i ]; j ++) {
65+ String id = String .valueOf (i * 100 + j );
66+ String keywordValue = String .valueOf (numDocs [i ] - j );
67+ float [] vector = randomFloatVector (dims );
68+ bulkRequest1 .add (prepareIndex (indexName1 ).setId (id ).setSource ("my_vector" , vector , "my_keyword" , keywordValue ));
69+ bulkRequest2 .add (prepareIndex (indexName2 ).setId (id ).setSource ("my_vector" , vector , "my_keyword" , keywordValue ));
70+ }
71+ BulkResponse bulkResponse1 = bulkRequest1 .get ();
72+ assertFalse ("Bulk request failed: " + bulkResponse1 .buildFailureMessage (), bulkResponse1 .hasFailures ());
73+ BulkResponse bulkResponse2 = bulkRequest2 .get ();
74+ assertFalse ("Bulk request failed: " + bulkResponse2 .buildFailureMessage (), bulkResponse2 .hasFailures ());
75+ }
76+ refresh ();
77+
78+ float [] queryVector = randomFloatVector (dims );
79+ int k = 10 ;
80+ int numCandidates = k * 10 ;
81+
82+ var searchResponse1 = prepareSearch (indexName1 ).setSize (k )
83+ .setFetchSource (false )
84+ .addFetchField ("my_keyword" )
85+ .setKnnSearch (List .of (new KnnSearchBuilder ("my_vector" , queryVector , k , numCandidates , null , null )))
86+ .get ();
87+
88+ var searchResponse2 = prepareSearch (indexName2 ).setSize (k )
89+ .setFetchSource (false )
90+ .addFetchField ("my_keyword" )
91+ .setKnnSearch (List .of (new KnnSearchBuilder ("my_vector" , queryVector , k , numCandidates , null , null )))
92+ .get ();
93+
94+ try {
95+ SearchHit [] hits1 = searchResponse1 .getHits ().getHits ();
96+ SearchHit [] hits2 = searchResponse2 .getHits ().getHits ();
97+ Assert .assertEquals (hits1 .length , hits2 .length );
98+ for (int i = 0 ; i < hits1 .length ; i ++) {
99+ Assert .assertEquals (hits1 [i ].getId (), hits2 [i ].getId ());
100+ Assert .assertEquals ((String ) hits1 [i ].field ("my_keyword" ).getValue (), (String ) hits2 [i ].field ("my_keyword" ).getValue ());
101+ Assert .assertEquals (hits1 [i ].getScore (), hits2 [i ].getScore (), 0.0001f );
102+ }
103+ } finally {
104+ searchResponse1 .decRef ();
105+ searchResponse2 .decRef ();
106+ }
47107 }
48108
49109 public void testSearchWithoutGPU () {
50110 assumeTrue ("cuvs not supported" , GPUSupport .isSupported (false ));
111+ String indexName = "index1" ;
51112 final int dims = randomIntBetween (4 , 128 );
52113 final int numDocs = randomIntBetween (1 , 500 );
53- createIndex (dims );
114+ createIndex (indexName , dims , false );
54115 ensureGreen ();
55116
56- indexDocs (numDocs , dims , 0 );
117+ indexDocs (indexName , numDocs , dims , 0 );
57118 refresh ();
58119
59120 // update settings to disable GPU usage
60121 Settings .Builder settingsBuilder = Settings .builder ().put ("index.vectors.indexing.use_gpu" , false );
61- assertAcked (client ().admin ().indices ().prepareUpdateSettings ("foo-index" ).setSettings (settingsBuilder .build ()));
122+ assertAcked (client ().admin ().indices ().prepareUpdateSettings (indexName ).setSettings (settingsBuilder .build ()));
62123 ensureGreen ();
63- assertSearch (randomFloatVector (dims ), numDocs );
124+ assertSearch (indexName , randomFloatVector (dims ), numDocs );
64125 }
65126
66- private void createIndex (int dims ) {
127+ private void createIndex (String indexName , int dims , boolean sorted ) {
67128 var settings = Settings .builder ().put (indexSettings ());
68129 settings .put ("index.number_of_shards" , 1 );
69130 settings .put ("index.vectors.indexing.use_gpu" , true );
70- assertAcked (prepareCreate ("foo-index" ).setSettings (settings .build ()).setMapping (String .format (Locale .ROOT , """
131+ if (sorted ) {
132+ settings .put ("index.sort.field" , "my_keyword" );
133+ }
134+ assertAcked (prepareCreate (indexName ).setSettings (settings .build ()).setMapping (String .format (Locale .ROOT , """
71135 {
72136 "properties": {
73137 "my_vector": {
@@ -77,28 +141,36 @@ private void createIndex(int dims) {
77141 "index_options": {
78142 "type": "hnsw"
79143 }
144+ },
145+ "my_keyword": {
146+ "type": "keyword"
80147 }
81148 }
82149 }
83150 """ , dims )));
84151 ensureGreen ();
85152 }
86153
87- private void indexDocs (int numDocs , int dims , int startDoc ) {
154+ private void indexDocs (String indexName , int numDocs , int dims , int startDoc ) {
88155 BulkRequestBuilder bulkRequest = client ().prepareBulk ();
89156 for (int i = 0 ; i < numDocs ; i ++) {
90157 String id = String .valueOf (startDoc + i );
91- bulkRequest .add (prepareIndex ("foo-index" ).setId (id ).setSource ("my_vector" , randomFloatVector (dims )));
158+ String keywordValue = String .valueOf (numDocs - i );
159+ var indexRequest = prepareIndex (indexName ).setId (id )
160+ .setSource ("my_vector" , randomFloatVector (dims ), "my_keyword" , keywordValue );
161+ bulkRequest .add (indexRequest );
92162 }
93163 BulkResponse bulkResponse = bulkRequest .get ();
94164 assertFalse ("Bulk request failed: " + bulkResponse .buildFailureMessage (), bulkResponse .hasFailures ());
95165 }
96166
97- private void assertSearch (float [] queryVector , int totalDocs ) {
167+ private void assertSearch (String indexName , float [] queryVector , int totalDocs ) {
98168 int k = Math .min (randomIntBetween (1 , 20 ), totalDocs );
99169 int numCandidates = k * 10 ;
100170 assertNoFailuresAndResponse (
101- prepareSearch ("foo-index" ).setSize (k )
171+ prepareSearch (indexName ).setSize (k )
172+ .setFetchSource (false )
173+ .addFetchField ("my_keyword" )
102174 .setKnnSearch (List .of (new KnnSearchBuilder ("my_vector" , queryVector , k , numCandidates , null , null ))),
103175 response -> {
104176 assertEquals ("Expected k hits to be returned" , k , response .getHits ().getHits ().length );
0 commit comments