Skip to content

Commit f1539fd

Browse files
GPU: Test feature disabling and document mutations (#138966)
- Add test for GPU-to-non-GPU transition (indexing on GPU, with continuous search and indexing on CPU) - Add test for deletes/updates - Support configurable similarity metrics in test setup
1 parent 42c2aa2 commit f1539fd

File tree

1 file changed

+125
-7
lines changed
  • x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/xpack/gpu

1 file changed

+125
-7
lines changed

x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/xpack/gpu/GPUIndexIT.java

Lines changed: 125 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.search.vectors.KnnSearchBuilder;
1919
import org.elasticsearch.search.vectors.VectorData;
2020
import org.elasticsearch.test.ESIntegTestCase;
21+
import org.junit.After;
2122
import org.junit.Assert;
2223
import org.junit.BeforeClass;
2324

@@ -35,17 +36,27 @@
3536
@LuceneTestCase.SuppressCodecs("*") // use our custom codec
3637
public class GPUIndexIT extends ESIntegTestCase {
3738

39+
private static boolean isGpuIndexingFeatureAllowed = true;
40+
41+
private String similarity;
42+
3843
public static class TestGPUPlugin extends GPUPlugin {
3944
public TestGPUPlugin() {
4045
super(Settings.builder().put("vectors.indexing.use_gpu", GpuMode.TRUE.name()).build());
4146
}
4247

4348
@Override
4449
protected boolean isGpuIndexingFeatureAllowed() {
45-
return true;
50+
return GPUIndexIT.isGpuIndexingFeatureAllowed;
4651
}
4752
}
4853

54+
@After
55+
public void reset() {
56+
isGpuIndexingFeatureAllowed = true;
57+
similarity = null;
58+
}
59+
4960
@Override
5061
protected Collection<Class<? extends Plugin>> nodePlugins() {
5162
return List.of(TestGPUPlugin.class);
@@ -70,6 +81,36 @@ public void testBasic() {
7081
assertSearch(indexName, randomFloatVector(dims), totalDocs);
7182
}
7283

84+
public void testSearchAndIndexAfterDisablingGpu() {
85+
String indexName = "index1";
86+
final int dims = randomIntBetween(4, 128);
87+
final int numDocs = randomIntBetween(1, 500);
88+
createIndex(indexName, dims, false);
89+
ensureGreen();
90+
91+
indexDocs(indexName, numDocs, dims, 0);
92+
refresh();
93+
94+
// Disable GPU usage via feature flag (simulating missing license)
95+
isGpuIndexingFeatureAllowed = false;
96+
ensureGreen();
97+
98+
assertAcked(indicesAdmin().prepareClose(indexName).get());
99+
assertAcked(indicesAdmin().prepareOpen(indexName).get());
100+
ensureGreen();
101+
102+
assertSearch(indexName, randomFloatVector(dims), numDocs);
103+
104+
// Add more data to the index
105+
final int additionalDocs = randomIntBetween(1, 100);
106+
indexDocs(indexName, additionalDocs, dims, numDocs);
107+
refresh();
108+
final int totalDocs = numDocs + additionalDocs;
109+
110+
// Perform another search with the additional data
111+
assertSearch(indexName, randomFloatVector(dims), totalDocs);
112+
}
113+
73114
public void testSortedIndexReturnsSameResultsAsUnsorted() {
74115
String indexName1 = "index_unsorted";
75116
String indexName2 = "index_sorted";
@@ -193,6 +234,58 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
193234
}
194235
}
195236

237+
public void testDeletesUpdates() {
238+
String indexName = "index_deletes_updates";
239+
final int dims = randomIntBetween(4, 128);
240+
createIndex(indexName, dims, false);
241+
242+
final int numDocs = randomIntBetween(700, 1000);
243+
indexDocs(indexName, numDocs, dims, 0);
244+
refresh();
245+
246+
// Perform random updates and deletes
247+
final int numOperations = randomIntBetween(10, 50);
248+
BulkRequestBuilder bulkRequest = client().prepareBulk();
249+
for (int i = 0; i < numOperations; i++) {
250+
int docId = randomIntBetween(0, numDocs - 1);
251+
if (randomBoolean()) {
252+
bulkRequest.add(
253+
prepareIndex(indexName).setId(String.valueOf(docId))
254+
.setSource("my_vector", randomFloatVector(dims), "my_keyword", String.valueOf(randomIntBetween(1, numDocs)))
255+
);
256+
} else {
257+
bulkRequest.add(client().prepareDelete(indexName, String.valueOf(docId)));
258+
}
259+
}
260+
BulkResponse bulkResponse = bulkRequest.get();
261+
assertFalse("Bulk request failed: " + bulkResponse.buildFailureMessage(), bulkResponse.hasFailures());
262+
refresh();
263+
264+
// Assert that approximate and exact searches return same sets of results
265+
float[] queryVector = randomFloatVector(dims);
266+
int k = 10;
267+
int numCandidates = k * 10;
268+
269+
var approxSearchResponse = prepareSearch(indexName).setSize(k)
270+
.setFetchSource(false)
271+
.setKnnSearch(List.of(new KnnSearchBuilder("my_vector", queryVector, k, numCandidates, null, null, null)))
272+
.get();
273+
274+
var exactSearchResponse = prepareSearch(indexName).setSize(k)
275+
.setFetchSource(false)
276+
.setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
277+
.get();
278+
279+
try {
280+
SearchHit[] approxHits = approxSearchResponse.getHits().getHits();
281+
SearchHit[] exactHits = exactSearchResponse.getHits().getHits();
282+
assertAtLeastNOutOfKMatches(approxHits, exactHits, k - 3, k);
283+
} finally {
284+
approxSearchResponse.decRef();
285+
exactSearchResponse.decRef();
286+
}
287+
}
288+
196289
public void testInt8HnswMaxInnerProductProductFails() {
197290
String indexName = "index_int8_max_inner_product_fails";
198291
final int dims = randomIntBetween(4, 128);
@@ -222,7 +315,7 @@ public void testInt8HnswMaxInnerProductProductFails() {
222315
// Attempt to index a document and expect it to fail
223316
IllegalArgumentException ex = expectThrows(
224317
IllegalArgumentException.class,
225-
() -> client().prepareIndex(indexName).setId("1").setSource("my_vector", randomFloatVector(dims)).get()
318+
() -> client().prepareIndex(indexName).setId("1").setSource("my_vector", randomNonUnitFloatVector(dims)).get()
226319
);
227320
assertThat(
228321
ex.getMessage(),
@@ -237,14 +330,18 @@ private void createIndex(String indexName, int dims, boolean sorted) {
237330
settings.put("index.sort.field", "my_keyword");
238331
}
239332

333+
if (similarity == null) {
334+
similarity = randomFrom("dot_product", "l2_norm", "cosine");
335+
}
336+
240337
String type = randomFrom("hnsw", "int8_hnsw");
241338
String mapping = String.format(Locale.ROOT, """
242339
{
243340
"properties": {
244341
"my_vector": {
245342
"type": "dense_vector",
246343
"dims": %d,
247-
"similarity": "l2_norm",
344+
"similarity": "%s",
248345
"index_options": {
249346
"type": "%s"
250347
}
@@ -254,7 +351,7 @@ private void createIndex(String indexName, int dims, boolean sorted) {
254351
}
255352
}
256353
}
257-
""", dims, type);
354+
""", dims, similarity, type);
258355
assertAcked(prepareCreate(indexName).setSettings(settings.build()).setMapping(mapping));
259356
ensureGreen();
260357
}
@@ -264,8 +361,8 @@ private void indexDocs(String indexName, int numDocs, int dims, int startDoc) {
264361
for (int i = 0; i < numDocs; i++) {
265362
String id = String.valueOf(startDoc + i);
266363
String keywordValue = String.valueOf(numDocs - i);
267-
var indexRequest = prepareIndex(indexName).setId(id)
268-
.setSource("my_vector", randomFloatVector(dims), "my_keyword", keywordValue);
364+
float[] vector = randomFloatVector(dims);
365+
var indexRequest = prepareIndex(indexName).setId(id).setSource("my_vector", vector, "my_keyword", keywordValue);
269366
bulkRequest.add(indexRequest);
270367
}
271368
BulkResponse bulkResponse = bulkRequest.get();
@@ -284,14 +381,35 @@ private void assertSearch(String indexName, float[] queryVector, int totalDocs)
284381
);
285382
}
286383

287-
private static float[] randomFloatVector(int dims) {
384+
private static float[] randomNonUnitFloatVector(int dims) {
288385
float[] vector = new float[dims];
289386
for (int i = 0; i < dims; i++) {
290387
vector[i] = randomFloat();
291388
}
292389
return vector;
293390
}
294391

392+
private static float[] randomUnitVector(int dims) {
393+
float[] vector = new float[dims];
394+
double sumSquares = 0.0;
395+
for (int i = 0; i < dims; i++) {
396+
vector[i] = randomFloat() * 2 - 1; // Generate values between -1 and 1 for better distribution
397+
sumSquares += vector[i] * vector[i];
398+
}
399+
float magnitude = (float) Math.sqrt(sumSquares);
400+
if (magnitude > 0) {
401+
for (int i = 0; i < dims; i++) {
402+
vector[i] /= magnitude;
403+
}
404+
}
405+
return vector;
406+
}
407+
408+
private float[] randomFloatVector(int dims) {
409+
boolean useUnitVectors = "dot_product".equals(similarity);
410+
return useUnitVectors ? randomUnitVector(dims) : randomNonUnitFloatVector(dims);
411+
}
412+
295413
/**
296414
* Asserts that at least N out of K hits have matching IDs between two result sets.
297415
*/

0 commit comments

Comments
 (0)