1818import org .elasticsearch .search .vectors .KnnSearchBuilder ;
1919import org .elasticsearch .search .vectors .VectorData ;
2020import org .elasticsearch .test .ESIntegTestCase ;
21+ import org .junit .After ;
2122import org .junit .Assert ;
2223import org .junit .BeforeClass ;
2324
3536@ LuceneTestCase .SuppressCodecs ("*" ) // use our custom codec
3637public class GPUIndexIT extends ESIntegTestCase {
3738
39+ private static boolean isGpuIndexingFeatureAllowed = true ;
40+
41+ private String similarity ;
42+
3843 public static class TestGPUPlugin extends GPUPlugin {
3944 public TestGPUPlugin () {
4045 super (Settings .builder ().put ("vectors.indexing.use_gpu" , GpuMode .TRUE .name ()).build ());
4146 }
4247
4348 @ Override
4449 protected boolean isGpuIndexingFeatureAllowed () {
45- return true ;
50+ return GPUIndexIT . isGpuIndexingFeatureAllowed ;
4651 }
4752 }
4853
54+ @ After
55+ public void reset () {
56+ isGpuIndexingFeatureAllowed = true ;
57+ similarity = null ;
58+ }
59+
4960 @ Override
5061 protected Collection <Class <? extends Plugin >> nodePlugins () {
5162 return List .of (TestGPUPlugin .class );
@@ -70,6 +81,36 @@ public void testBasic() {
7081 assertSearch (indexName , randomFloatVector (dims ), totalDocs );
7182 }
7283
84+ public void testSearchAndIndexAfterDisablingGpu () {
85+ String indexName = "index1" ;
86+ final int dims = randomIntBetween (4 , 128 );
87+ final int numDocs = randomIntBetween (1 , 500 );
88+ createIndex (indexName , dims , false );
89+ ensureGreen ();
90+
91+ indexDocs (indexName , numDocs , dims , 0 );
92+ refresh ();
93+
94+ // Disable GPU usage via feature flag (simulating missing license)
95+ isGpuIndexingFeatureAllowed = false ;
96+ ensureGreen ();
97+
98+ assertAcked (indicesAdmin ().prepareClose (indexName ).get ());
99+ assertAcked (indicesAdmin ().prepareOpen (indexName ).get ());
100+ ensureGreen ();
101+
102+ assertSearch (indexName , randomFloatVector (dims ), numDocs );
103+
104+ // Add more data to the index
105+ final int additionalDocs = randomIntBetween (1 , 100 );
106+ indexDocs (indexName , additionalDocs , dims , numDocs );
107+ refresh ();
108+ final int totalDocs = numDocs + additionalDocs ;
109+
110+ // Perform another search with the additional data
111+ assertSearch (indexName , randomFloatVector (dims ), totalDocs );
112+ }
113+
73114 public void testSortedIndexReturnsSameResultsAsUnsorted () {
74115 String indexName1 = "index_unsorted" ;
75116 String indexName2 = "index_sorted" ;
@@ -193,6 +234,58 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
193234 }
194235 }
195236
237+ public void testDeletesUpdates () {
238+ String indexName = "index_deletes_updates" ;
239+ final int dims = randomIntBetween (4 , 128 );
240+ createIndex (indexName , dims , false );
241+
242+ final int numDocs = randomIntBetween (700 , 1000 );
243+ indexDocs (indexName , numDocs , dims , 0 );
244+ refresh ();
245+
246+ // Perform random updates and deletes
247+ final int numOperations = randomIntBetween (10 , 50 );
248+ BulkRequestBuilder bulkRequest = client ().prepareBulk ();
249+ for (int i = 0 ; i < numOperations ; i ++) {
250+ int docId = randomIntBetween (0 , numDocs - 1 );
251+ if (randomBoolean ()) {
252+ bulkRequest .add (
253+ prepareIndex (indexName ).setId (String .valueOf (docId ))
254+ .setSource ("my_vector" , randomFloatVector (dims ), "my_keyword" , String .valueOf (randomIntBetween (1 , numDocs )))
255+ );
256+ } else {
257+ bulkRequest .add (client ().prepareDelete (indexName , String .valueOf (docId )));
258+ }
259+ }
260+ BulkResponse bulkResponse = bulkRequest .get ();
261+ assertFalse ("Bulk request failed: " + bulkResponse .buildFailureMessage (), bulkResponse .hasFailures ());
262+ refresh ();
263+
264+ // Assert that approximate and exact searches return same sets of results
265+ float [] queryVector = randomFloatVector (dims );
266+ int k = 10 ;
267+ int numCandidates = k * 10 ;
268+
269+ var approxSearchResponse = prepareSearch (indexName ).setSize (k )
270+ .setFetchSource (false )
271+ .setKnnSearch (List .of (new KnnSearchBuilder ("my_vector" , queryVector , k , numCandidates , null , null , null )))
272+ .get ();
273+
274+ var exactSearchResponse = prepareSearch (indexName ).setSize (k )
275+ .setFetchSource (false )
276+ .setQuery (new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector ), "my_vector" , null ))
277+ .get ();
278+
279+ try {
280+ SearchHit [] approxHits = approxSearchResponse .getHits ().getHits ();
281+ SearchHit [] exactHits = exactSearchResponse .getHits ().getHits ();
282+ assertAtLeastNOutOfKMatches (approxHits , exactHits , k - 3 , k );
283+ } finally {
284+ approxSearchResponse .decRef ();
285+ exactSearchResponse .decRef ();
286+ }
287+ }
288+
196289 public void testInt8HnswMaxInnerProductProductFails () {
197290 String indexName = "index_int8_max_inner_product_fails" ;
198291 final int dims = randomIntBetween (4 , 128 );
@@ -222,7 +315,7 @@ public void testInt8HnswMaxInnerProductProductFails() {
222315 // Attempt to index a document and expect it to fail
223316 IllegalArgumentException ex = expectThrows (
224317 IllegalArgumentException .class ,
225- () -> client ().prepareIndex (indexName ).setId ("1" ).setSource ("my_vector" , randomFloatVector (dims )).get ()
318+ () -> client ().prepareIndex (indexName ).setId ("1" ).setSource ("my_vector" , randomNonUnitFloatVector (dims )).get ()
226319 );
227320 assertThat (
228321 ex .getMessage (),
@@ -237,14 +330,18 @@ private void createIndex(String indexName, int dims, boolean sorted) {
237330 settings .put ("index.sort.field" , "my_keyword" );
238331 }
239332
333+ if (similarity == null ) {
334+ similarity = randomFrom ("dot_product" , "l2_norm" , "cosine" );
335+ }
336+
240337 String type = randomFrom ("hnsw" , "int8_hnsw" );
241338 String mapping = String .format (Locale .ROOT , """
242339 {
243340 "properties": {
244341 "my_vector": {
245342 "type": "dense_vector",
246343 "dims": %d,
247- "similarity": "l2_norm ",
344+ "similarity": "%s ",
248345 "index_options": {
249346 "type": "%s"
250347 }
@@ -254,7 +351,7 @@ private void createIndex(String indexName, int dims, boolean sorted) {
254351 }
255352 }
256353 }
257- """ , dims , type );
354+ """ , dims , similarity , type );
258355 assertAcked (prepareCreate (indexName ).setSettings (settings .build ()).setMapping (mapping ));
259356 ensureGreen ();
260357 }
@@ -264,8 +361,8 @@ private void indexDocs(String indexName, int numDocs, int dims, int startDoc) {
264361 for (int i = 0 ; i < numDocs ; i ++) {
265362 String id = String .valueOf (startDoc + i );
266363 String keywordValue = String .valueOf (numDocs - i );
267- var indexRequest = prepareIndex ( indexName ). setId ( id )
268- . setSource ("my_vector" , randomFloatVector ( dims ) , "my_keyword" , keywordValue );
364+ float [] vector = randomFloatVector ( dims );
365+ var indexRequest = prepareIndex ( indexName ). setId ( id ). setSource ("my_vector" , vector , "my_keyword" , keywordValue );
269366 bulkRequest .add (indexRequest );
270367 }
271368 BulkResponse bulkResponse = bulkRequest .get ();
@@ -284,14 +381,35 @@ private void assertSearch(String indexName, float[] queryVector, int totalDocs)
284381 );
285382 }
286383
287- private static float [] randomFloatVector (int dims ) {
384+ private static float [] randomNonUnitFloatVector (int dims ) {
288385 float [] vector = new float [dims ];
289386 for (int i = 0 ; i < dims ; i ++) {
290387 vector [i ] = randomFloat ();
291388 }
292389 return vector ;
293390 }
294391
392+ private static float [] randomUnitVector (int dims ) {
393+ float [] vector = new float [dims ];
394+ double sumSquares = 0.0 ;
395+ for (int i = 0 ; i < dims ; i ++) {
396+ vector [i ] = randomFloat () * 2 - 1 ; // Generate values between -1 and 1 for better distribution
397+ sumSquares += vector [i ] * vector [i ];
398+ }
399+ float magnitude = (float ) Math .sqrt (sumSquares );
400+ if (magnitude > 0 ) {
401+ for (int i = 0 ; i < dims ; i ++) {
402+ vector [i ] /= magnitude ;
403+ }
404+ }
405+ return vector ;
406+ }
407+
408+ private float [] randomFloatVector (int dims ) {
409+ boolean useUnitVectors = "dot_product" .equals (similarity );
410+ return useUnitVectors ? randomUnitVector (dims ) : randomNonUnitFloatVector (dims );
411+ }
412+
295413 /**
296414 * Asserts that at least N out of K hits have matching IDs between two result sets.
297415 */
0 commit comments