Skip to content

Commit dddadce

Browse files
Implement working with sorted index
1 parent e6e9c17 commit dddadce

File tree

2 files changed

+89
-16
lines changed

2 files changed

+89
-16
lines changed

x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
import org.elasticsearch.action.bulk.BulkResponse;
1313
import org.elasticsearch.common.settings.Settings;
1414
import org.elasticsearch.plugins.Plugin;
15+
import org.elasticsearch.search.SearchHit;
1516
import org.elasticsearch.search.vectors.KnnSearchBuilder;
1617
import org.elasticsearch.test.ESIntegTestCase;
1718
import org.elasticsearch.xpack.gpu.GPUPlugin;
1819
import org.elasticsearch.xpack.gpu.GPUSupport;
20+
import org.junit.Assert;
1921

2022
import java.util.Collection;
2123
import java.util.List;
@@ -34,40 +36,102 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
3436

3537
public void testBasic() {
3638
assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
39+
String indexName = "index1";
3740
final int dims = randomIntBetween(4, 128);
3841
final int[] numDocs = new int[] { randomIntBetween(1, 100), 1, 2, randomIntBetween(1, 100) };
39-
createIndex(dims);
42+
createIndex(indexName, dims, false);
4043
int totalDocs = 0;
4144
for (int i = 0; i < numDocs.length; i++) {
42-
indexDocs(numDocs[i], dims, i * 100);
45+
indexDocs(indexName, numDocs[i], dims, i * 100);
4346
totalDocs += numDocs[i];
4447
}
4548
refresh();
46-
assertSearch(randomFloatVector(dims), totalDocs);
49+
assertSearch(indexName, randomFloatVector(dims), totalDocs);
50+
}
51+
52+
public void testSortedIndexReturnsSameResultsAsUnsorted() {
53+
assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
54+
String indexName1 = "index_unsorted";
55+
String indexName2 = "index_sorted";
56+
final int dims = randomIntBetween(4, 128);
57+
createIndex(indexName1, dims, false);
58+
createIndex(indexName2, dims, true);
59+
60+
final int[] numDocs = new int[] { randomIntBetween(50, 100), randomIntBetween(50, 100) };
61+
for (int i = 0; i < numDocs.length; i++) {
62+
BulkRequestBuilder bulkRequest1 = client().prepareBulk();
63+
BulkRequestBuilder bulkRequest2 = client().prepareBulk();
64+
for (int j = 0; j < numDocs[i]; j++) {
65+
String id = String.valueOf(i * 100 + j);
66+
String keywordValue = String.valueOf(numDocs[i] - j);
67+
float[] vector = randomFloatVector(dims);
68+
bulkRequest1.add(prepareIndex(indexName1).setId(id).setSource("my_vector", vector, "my_keyword", keywordValue));
69+
bulkRequest2.add(prepareIndex(indexName2).setId(id).setSource("my_vector", vector, "my_keyword", keywordValue));
70+
}
71+
BulkResponse bulkResponse1 = bulkRequest1.get();
72+
assertFalse("Bulk request failed: " + bulkResponse1.buildFailureMessage(), bulkResponse1.hasFailures());
73+
BulkResponse bulkResponse2 = bulkRequest2.get();
74+
assertFalse("Bulk request failed: " + bulkResponse2.buildFailureMessage(), bulkResponse2.hasFailures());
75+
}
76+
refresh();
77+
78+
float[] queryVector = randomFloatVector(dims);
79+
int k = 10;
80+
int numCandidates = k * 10;
81+
82+
var searchResponse1 = prepareSearch(indexName1).setSize(k)
83+
.setFetchSource(false)
84+
.addFetchField("my_keyword")
85+
.setKnnSearch(List.of(new KnnSearchBuilder("my_vector", queryVector, k, numCandidates, null, null)))
86+
.get();
87+
88+
var searchResponse2 = prepareSearch(indexName2).setSize(k)
89+
.setFetchSource(false)
90+
.addFetchField("my_keyword")
91+
.setKnnSearch(List.of(new KnnSearchBuilder("my_vector", queryVector, k, numCandidates, null, null)))
92+
.get();
93+
94+
try {
95+
SearchHit[] hits1 = searchResponse1.getHits().getHits();
96+
SearchHit[] hits2 = searchResponse2.getHits().getHits();
97+
Assert.assertEquals(hits1.length, hits2.length);
98+
for (int i = 0; i < hits1.length; i++) {
99+
Assert.assertEquals(hits1[i].getId(), hits2[i].getId());
100+
Assert.assertEquals((String) hits1[i].field("my_keyword").getValue(), (String) hits2[i].field("my_keyword").getValue());
101+
Assert.assertEquals(hits1[i].getScore(), hits2[i].getScore(), 0.0001f);
102+
}
103+
} finally {
104+
searchResponse1.decRef();
105+
searchResponse2.decRef();
106+
}
47107
}
48108

49109
public void testSearchWithoutGPU() {
50110
assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
111+
String indexName = "index1";
51112
final int dims = randomIntBetween(4, 128);
52113
final int numDocs = randomIntBetween(1, 500);
53-
createIndex(dims);
114+
createIndex(indexName, dims, false);
54115
ensureGreen();
55116

56-
indexDocs(numDocs, dims, 0);
117+
indexDocs(indexName, numDocs, dims, 0);
57118
refresh();
58119

59120
// update settings to disable GPU usage
60121
Settings.Builder settingsBuilder = Settings.builder().put("index.vectors.indexing.use_gpu", false);
61-
assertAcked(client().admin().indices().prepareUpdateSettings("foo-index").setSettings(settingsBuilder.build()));
122+
assertAcked(client().admin().indices().prepareUpdateSettings(indexName).setSettings(settingsBuilder.build()));
62123
ensureGreen();
63-
assertSearch(randomFloatVector(dims), numDocs);
124+
assertSearch(indexName, randomFloatVector(dims), numDocs);
64125
}
65126

66-
private void createIndex(int dims) {
127+
private void createIndex(String indexName, int dims, boolean sorted) {
67128
var settings = Settings.builder().put(indexSettings());
68129
settings.put("index.number_of_shards", 1);
69130
settings.put("index.vectors.indexing.use_gpu", true);
70-
assertAcked(prepareCreate("foo-index").setSettings(settings.build()).setMapping(String.format(Locale.ROOT, """
131+
if (sorted) {
132+
settings.put("index.sort.field", "my_keyword");
133+
}
134+
assertAcked(prepareCreate(indexName).setSettings(settings.build()).setMapping(String.format(Locale.ROOT, """
71135
{
72136
"properties": {
73137
"my_vector": {
@@ -77,28 +141,36 @@ private void createIndex(int dims) {
77141
"index_options": {
78142
"type": "hnsw"
79143
}
144+
},
145+
"my_keyword": {
146+
"type": "keyword"
80147
}
81148
}
82149
}
83150
""", dims)));
84151
ensureGreen();
85152
}
86153

87-
private void indexDocs(int numDocs, int dims, int startDoc) {
154+
private void indexDocs(String indexName, int numDocs, int dims, int startDoc) {
88155
BulkRequestBuilder bulkRequest = client().prepareBulk();
89156
for (int i = 0; i < numDocs; i++) {
90157
String id = String.valueOf(startDoc + i);
91-
bulkRequest.add(prepareIndex("foo-index").setId(id).setSource("my_vector", randomFloatVector(dims)));
158+
String keywordValue = String.valueOf(numDocs - i);
159+
var indexRequest = prepareIndex(indexName).setId(id)
160+
.setSource("my_vector", randomFloatVector(dims), "my_keyword", keywordValue);
161+
bulkRequest.add(indexRequest);
92162
}
93163
BulkResponse bulkResponse = bulkRequest.get();
94164
assertFalse("Bulk request failed: " + bulkResponse.buildFailureMessage(), bulkResponse.hasFailures());
95165
}
96166

97-
private void assertSearch(float[] queryVector, int totalDocs) {
167+
private void assertSearch(String indexName, float[] queryVector, int totalDocs) {
98168
int k = Math.min(randomIntBetween(1, 20), totalDocs);
99169
int numCandidates = k * 10;
100170
assertNoFailuresAndResponse(
101-
prepareSearch("foo-index").setSize(k)
171+
prepareSearch(indexName).setSize(k)
172+
.setFetchSource(false)
173+
.addFetchField("my_keyword")
102174
.setKnnSearch(List.of(new KnnSearchBuilder("my_vector", queryVector, k, numCandidates, null, null))),
103175
response -> {
104176
assertEquals("Expected k hits to be returned", k, response.getHits().getHits().length);

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,10 @@ private void writeField(FieldWriter fieldWriter) throws IOException {
221221
}
222222

223223
private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
224-
// TODO: implement writing sorted field when we can access cagra index through MemorySegment
225-
// as we need random access to neighbors in the graph.
226-
throw new UnsupportedOperationException("Writing field with index sorted needs to be implemented.");
224+
// The flatFieldVectorsWriter's flush method, called before this, has already sorted the vectors according to the sortMap.
225+
// We can now treat them as a simple, sorted list of vectors.
226+
float[][] vectors = fieldData.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
227+
writeFieldInternal(fieldData.fieldInfo, DatasetOrVectors.fromArray(vectors));
227228
}
228229

229230
private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {

0 commit comments

Comments
 (0)