Skip to content

Commit 7c85493

Browse files
Fix merging in GPU index writer
1 parent 4826da6 commit 7c85493

File tree

1 file changed

+134
-15
lines changed

1 file changed

+134
-15
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 134 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import com.nvidia.cuvs.CagraIndex;
1111
import com.nvidia.cuvs.CagraIndexParams;
1212
import com.nvidia.cuvs.CuVSResources;
13+
import com.nvidia.cuvs.Dataset;
1314

1415
import org.apache.lucene.codecs.CodecUtil;
1516
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
@@ -26,6 +27,7 @@
2627
import org.apache.lucene.index.Sorter;
2728
import org.apache.lucene.index.VectorEncoding;
2829
import org.apache.lucene.index.VectorSimilarityFunction;
30+
import org.apache.lucene.store.IOContext;
2931
import org.apache.lucene.store.IndexInput;
3032
import org.apache.lucene.store.IndexOutput;
3133
import org.apache.lucene.util.RamUsageEstimator;
@@ -39,6 +41,9 @@
3941
import org.elasticsearch.logging.Logger;
4042

4143
import java.io.IOException;
44+
import java.io.UncheckedIOException;
45+
import java.nio.ByteBuffer;
46+
import java.nio.ByteOrder;
4247
import java.util.ArrayList;
4348
import java.util.Arrays;
4449
import java.util.List;
@@ -166,9 +171,46 @@ public long ramBytesUsed() {
166171
return total;
167172
}
168173

174+
private static final class DatasetOrVectors {
175+
private final Dataset dataset;
176+
private final float[][] vectors;
177+
178+
DatasetOrVectors(float[][] vectors) {
179+
this(
180+
vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : Dataset.ofArray(vectors),
181+
vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? vectors : null
182+
);
183+
validateState();
184+
}
185+
186+
private DatasetOrVectors(Dataset dataset, float[][] vectors) {
187+
this.dataset = dataset;
188+
this.vectors = vectors;
189+
validateState();
190+
}
191+
192+
private void validateState() {
193+
if ((dataset == null && vectors == null) || (dataset != null && vectors != null)) {
194+
throw new IllegalStateException("Exactly one of dataset or vectors must be non-null");
195+
}
196+
}
197+
198+
int size() {
199+
return dataset != null ? dataset.size() : vectors.length;
200+
}
201+
202+
Dataset getDataset() {
203+
return dataset;
204+
}
205+
206+
float[][] getVectors() {
207+
return vectors;
208+
}
209+
}
210+
169211
private void writeField(FieldWriter fieldWriter) throws IOException {
170212
float[][] vectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
171-
writeFieldInternal(fieldWriter.fieldInfo, vectors);
213+
writeFieldInternal(fieldWriter.fieldInfo, new DatasetOrVectors(vectors));
172214
}
173215

174216
private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
@@ -177,12 +219,13 @@ private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) thr
177219
throw new UnsupportedOperationException("Writing field with index sorted needs to be implemented.");
178220
}
179221

180-
private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws IOException {
222+
private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
181223
try {
182224
long vectorIndexOffset = vectorIndex.getFilePointer();
183225
int[][] graphLevelNodeOffsets = new int[1][];
184226
HnswGraph mockGraph;
185-
if (vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
227+
if (datasetOrVectors.vectors != null) {
228+
float[][] vectors = datasetOrVectors.vectors;
186229
if (logger.isDebugEnabled()) {
187230
logger.debug(
188231
"Skip building carga index; vectors length {} < {} (min for GPU)",
@@ -192,12 +235,12 @@ private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws I
192235
}
193236
mockGraph = writeGraph(vectors, graphLevelNodeOffsets);
194237
} else {
195-
String tempCagraHNSWFileName = buildGPUIndex(fieldInfo.getVectorSimilarityFunction(), vectors);
238+
String tempCagraHNSWFileName = buildGPUIndex(fieldInfo.getVectorSimilarityFunction(), datasetOrVectors.dataset);
196239
assert tempCagraHNSWFileName != null : "GPU index should be built for field: " + fieldInfo.name;
197240
mockGraph = writeGraph(tempCagraHNSWFileName, graphLevelNodeOffsets);
198241
}
199242
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
200-
writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, vectors.length, mockGraph, graphLevelNodeOffsets);
243+
writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, datasetOrVectors.size(), mockGraph, graphLevelNodeOffsets);
201244
} catch (IOException e) {
202245
throw e;
203246
} catch (Throwable t) {
@@ -206,7 +249,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws I
206249
}
207250

208251
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
209-
private String buildGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
252+
private String buildGPUIndex(VectorSimilarityFunction similarityFunction, Dataset dataset) throws Throwable {
210253
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
211254
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
212255
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
@@ -221,9 +264,9 @@ private String buildGPUIndex(VectorSimilarityFunction similarityFunction, float[
221264

222265
// build index on GPU
223266
long startTime = System.nanoTime();
224-
var index = CagraIndex.newBuilder(cuVSResources).withDataset(vectors).withIndexParams(params).build();
267+
var index = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(params).build();
225268
if (logger.isDebugEnabled()) {
226-
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, vectors.length);
269+
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size());
227270
}
228271

229272
// TODO: do serialization through MemorySegment instead of a temp file
@@ -419,18 +462,94 @@ public NodesIterator getNodesOnLevel(int level) {
419462

420463
// TODO check with deleted documents
421464
@Override
465+
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
422466
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
423467
flatVectorWriter.mergeOneField(fieldInfo, mergeState);
424468
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
425-
// TODO: more efficient way to pass merged vector values to gpuIndex construction
426-
KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
427-
List<float[]> vectorList = new ArrayList<>();
428-
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
429-
vectorList.add(vectorValues.vectorValue(iter.index()));
469+
// save merged vector values to a temp file
470+
final int numVectors;
471+
String tempRawVectorsFileName = null;
472+
boolean success = false;
473+
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "vec_", IOContext.DEFAULT)) {
474+
tempRawVectorsFileName = out.getName();
475+
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
476+
CodecUtil.writeFooter(out);
477+
success = true;
478+
} finally {
479+
if (success == false && tempRawVectorsFileName != null) {
480+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
481+
}
482+
}
483+
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
484+
// TODO: Improve this (not acceptable): pass tempRawVectorsFileName for the gpuIndex construction through MemorySegment
485+
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
486+
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
487+
float[] vector;
488+
for (int i = 0; i < numVectors; i++) {
489+
vector = floatVectorValues.vectorValue(i);
490+
System.arraycopy(vector, 0, vectors[i], 0, vector.length);
491+
}
492+
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(vectors);
493+
writeFieldInternal(fieldInfo, datasetOrVectors);
494+
} finally {
495+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
430496
}
431-
float[][] vectors = vectorList.toArray(new float[0][]);
497+
}
498+
499+
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
500+
throws IOException {
501+
int numVectors = 0;
502+
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
503+
final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
504+
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
505+
numVectors++;
506+
float[] vector = floatVectorValues.vectorValue(iterator.index());
507+
out.writeInt(iterator.docID());
508+
buffer.asFloatBuffer().put(vector);
509+
out.writeBytes(buffer.array(), buffer.array().length);
510+
}
511+
return numVectors;
512+
}
513+
514+
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
515+
if (numVectors == 0) {
516+
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
517+
}
518+
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
519+
final float[] vector = new float[fieldInfo.getVectorDimension()];
520+
return new FloatVectorValues() {
521+
@Override
522+
public float[] vectorValue(int ord) throws IOException {
523+
randomAccessInput.seek(ord * length + Integer.BYTES);
524+
randomAccessInput.readFloats(vector, 0, vector.length);
525+
return vector;
526+
}
527+
528+
@Override
529+
public FloatVectorValues copy() {
530+
return this;
531+
}
432532

433-
writeFieldInternal(fieldInfo, vectors);
533+
@Override
534+
public int dimension() {
535+
return fieldInfo.getVectorDimension();
536+
}
537+
538+
@Override
539+
public int size() {
540+
return numVectors;
541+
}
542+
543+
@Override
544+
public int ordToDoc(int ord) {
545+
try {
546+
randomAccessInput.seek(ord * length);
547+
return randomAccessInput.readInt();
548+
} catch (IOException e) {
549+
throw new UncheckedIOException(e);
550+
}
551+
}
552+
};
434553
}
435554

436555
private void writeMeta(

0 commit comments

Comments
 (0)