Skip to content

Commit 8cee75e

Browse files
Build cagra index (iter1)
1 parent 2ac22b3 commit 8cee75e

File tree

5 files changed

+99
-19
lines changed

5 files changed

+99
-19
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormat.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,19 @@ static GPUVectorsReader getGPUReader(KnnVectorsReader vectorsReader, String fiel
7777
}
7878

7979
/** Tells whether the platform supports cuvs. */
80-
public static boolean supported() {
81-
try (var resources = CuVSResources.create()) {
82-
return true;
80+
public static CuVSResources cuVSResourcesOrNull() {
81+
try {
82+
var resources = CuVSResources.create();
83+
return resources;
8384
} catch (UnsupportedOperationException uoe) {
8485
var msg = uoe.getMessage() == null ? "" : ": " + uoe.getMessage();
85-
LOG.warn("cuvs is not supported on this platform or java version" + msg);
86+
LOG.warn("GPU based vector search is not supported on this platform or java version" + msg);
8687
} catch (Throwable t) {
8788
if (t instanceof ExceptionInInitializerError ex) {
8889
t = ex.getCause();
8990
}
9091
LOG.warn("Exception occurred during creation of cuvs resources. " + t);
9192
}
92-
return false;
93+
return null;
9394
}
9495
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsWriter.java

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
package org.elasticsearch.xpack.gpu.codec;
99

10+
import com.nvidia.cuvs.CagraIndex;
11+
import com.nvidia.cuvs.CagraIndexParams;
12+
import com.nvidia.cuvs.CuVSResources;
13+
1014
import org.apache.lucene.codecs.CodecUtil;
1115
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
1216
import org.apache.lucene.codecs.KnnVectorsWriter;
@@ -15,33 +19,46 @@
1519
import org.apache.lucene.index.FieldInfo;
1620
import org.apache.lucene.index.FloatVectorValues;
1721
import org.apache.lucene.index.IndexFileNames;
22+
import org.apache.lucene.index.KnnVectorValues;
1823
import org.apache.lucene.index.MergeState;
1924
import org.apache.lucene.index.SegmentWriteState;
2025
import org.apache.lucene.index.Sorter;
2126
import org.apache.lucene.index.VectorEncoding;
2227
import org.apache.lucene.index.VectorSimilarityFunction;
2328
import org.apache.lucene.store.IndexOutput;
29+
import org.elasticsearch.common.lucene.store.IndexOutputOutputStream;
2430
import org.elasticsearch.core.IOUtils;
31+
import org.elasticsearch.logging.LogManager;
32+
import org.elasticsearch.logging.Logger;
2533

2634
import java.io.IOException;
2735
import java.util.ArrayList;
2836
import java.util.List;
2937

3038
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
39+
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
3140

3241
/**
3342
* Writer for GPU-accelerated vectors.
3443
*/
3544
public class GPUVectorsWriter extends KnnVectorsWriter {
45+
private static final Logger logger = LogManager.getLogger(GPUVectorsWriter.class);
46+
// 2 for now based on https://github.com/rapidsai/cuvs/issues/666, but can be increased later
47+
private static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2;
3648

3749
private final List<FieldWriter> fieldWriters = new ArrayList<>();
3850
private final IndexOutput gpuIdx;
3951
private final IndexOutput gpuMeta;
4052
private final FlatVectorsWriter rawVectorDelegate;
4153
private final SegmentWriteState segmentWriteState;
54+
private final CuVSResources cuVSResources;
4255

4356
@SuppressWarnings("this-escape")
4457
public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
58+
this.cuVSResources = GPUVectorsFormat.cuVSResourcesOrNull();
59+
if (cuVSResources == null) {
60+
throw new IllegalArgumentException("GPU based vector search is not supported on this platform or java version");
61+
}
4562
this.segmentWriteState = state;
4663
this.rawVectorDelegate = rawVectorDelegate;
4764
final String metaFileName = IndexFileNames.segmentFileName(
@@ -95,30 +112,89 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
95112
@Override
96113
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
97114
rawVectorDelegate.flush(maxDoc, sortMap);
115+
// TODO: implement the case when sortMap != null
116+
98117
for (FieldWriter fieldWriter : fieldWriters) {
99-
// TODO: Implement GPU-specific vector merging instead of bogus implementation
118+
// TODO: can we use MemorySegment instead of passing array of vectors
119+
float[][] vectors = fieldWriter.delegate.getVectors().toArray(float[][]::new);
100120
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
101-
var vectors = fieldWriter.delegate.getVectors();
102-
for (int i = 0; i < vectors.size(); i++) {
103-
gpuIdx.writeVInt(0);
121+
try {
122+
buildAndwriteGPUIndex(fieldWriter.fieldInfo.getVectorSimilarityFunction(), vectors);
123+
long dataLength = gpuIdx.getFilePointer() - dataOffset;
124+
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
125+
} catch (IOException e) {
126+
throw e;
127+
} catch (Throwable t) {
128+
throw new IOException("Failed to write GPU index: ", t);
129+
}
130+
}
131+
}
132+
133+
private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
134+
// TODO: should we Lucene HNSW index write here
135+
if (vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
136+
if (logger.isDebugEnabled()) {
137+
logger.debug("Skip building carga index; vectors length {} < {}", vectors.length, MIN_NUM_VECTORS_FOR_GPU_BUILD);
138+
}
139+
return;
140+
}
141+
142+
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
143+
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
144+
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
145+
case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
146+
};
147+
148+
// TODO: expose cagra index params of intermediate graph degree, graph degre, algorithm, NNDescentNumIterations
149+
CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
150+
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
151+
.withMetric(distanceType)
152+
.build();
153+
154+
// build index on GPU
155+
long startTime = System.nanoTime();
156+
var index = CagraIndex.newBuilder(cuVSResources).withDataset(vectors).withIndexParams(params).build();
157+
if (logger.isDebugEnabled()) {
158+
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, vectors.length);
159+
}
160+
161+
// TODO: do serialization through MemorySegment instead of a temp file
162+
// serialize index for CPU consumption
163+
startTime = System.nanoTime();
164+
var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx);
165+
try {
166+
index.serialize(gpuIndexOutputStream);
167+
if (logger.isDebugEnabled()) {
168+
logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
104169
}
105-
long dataLength = gpuIdx.getFilePointer() - dataOffset;
106-
writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength);
170+
} finally {
171+
index.destroyIndex();
107172
}
108173
}
109174

110175
@Override
111176
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
112177
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
113178
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
114-
// TODO: Implement GPU-specific vector merging instead of bogus implementation
115-
FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
179+
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
180+
// TODO: more efficient way to pass merged vector values to gpuIndex construction
181+
KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
182+
List<float[]> vectorList = new ArrayList<>();
183+
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
184+
vectorList.add(vectorValues.vectorValue(iter.index()));
185+
}
186+
float[][] vectors = vectorList.toArray(new float[0][]);
187+
116188
long dataOffset = gpuIdx.alignFilePointer(Float.BYTES);
117-
for (int i = 0; i < floatVectorValues.size(); i++) {
118-
gpuIdx.writeVInt(0);
189+
try {
190+
buildAndwriteGPUIndex(fieldInfo.getVectorSimilarityFunction(), vectors);
191+
long dataLength = gpuIdx.getFilePointer() - dataOffset;
192+
writeMeta(fieldInfo, dataOffset, dataLength);
193+
} catch (IOException e) {
194+
throw e;
195+
} catch (Throwable t) {
196+
throw new IOException("Failed to write GPU index: ", t);
119197
}
120-
long dataLength = gpuIdx.getFilePointer() - dataOffset;
121-
writeMeta(fieldInfo, dataOffset, dataLength);
122198
} else {
123199
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
124200
}
@@ -157,6 +233,7 @@ public final void finish() throws IOException {
157233
@Override
158234
public final void close() throws IOException {
159235
IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx);
236+
cuVSResources.close();
160237
}
161238

162239
@Override
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
com.nvidia.cuvs:
2+
- load_native_libraries

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormatTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class GPUVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
2525

2626
@BeforeClass
2727
public static void beforeClass() {
28-
assumeTrue("cuvs not supported", GPUVectorsFormat.supported());
28+
assumeTrue("cuvs not supported", GPUVectorsFormat.cuVSResourcesOrNull() != null);
2929
}
3030

3131
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat());

x-pack/plugin/gpu/src/yamlRestTest/java/org/elasticsearch/xpack/gpu/GPUClientYamlTestSuiteIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class GPUClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
1919
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
2020
.nodes(1)
2121
.module("gpu")
22-
.setting("xpack.license.self_generated.type", "basic")
22+
.setting("xpack.license.self_generated.type", "trial")
2323
.setting("xpack.security.enabled", "false")
2424
.build();
2525

0 commit comments

Comments
 (0)