Skip to content

Commit ec3330e

Browse files
Attempt to use Java 22 DatasetImpl with MemorySegment
1 parent 7c85493 commit ec3330e

File tree

1 file changed

+39
-55
lines changed

1 file changed

+39
-55
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,13 @@
4242

4343
import java.io.IOException;
4444
import java.io.UncheckedIOException;
45+
import java.lang.foreign.Arena;
46+
import java.lang.foreign.MemorySegment;
4547
import java.nio.ByteBuffer;
4648
import java.nio.ByteOrder;
49+
import java.nio.channels.FileChannel;
50+
import java.nio.file.Path;
51+
import java.nio.file.StandardOpenOption;
4752
import java.util.ArrayList;
4853
import java.util.Arrays;
4954
import java.util.List;
@@ -460,36 +465,56 @@ public NodesIterator getNodesOnLevel(int level) {
460465
};
461466
}
462467

463-
// TODO check with deleted documents
464468
@Override
465469
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
466470
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
471+
int dims = fieldInfo.getVectorDimension();
467472
flatVectorWriter.mergeOneField(fieldInfo, mergeState);
468-
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
469-
// save merged vector values to a temp file
473+
FloatVectorValues mergeFloatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
474+
475+
if (mergeFloatVectorValues.size() < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
476+
// TODO: check how deleted documents affect size value
477+
KnnVectorValues.DocIndexIterator iter = mergeFloatVectorValues.iterator();
478+
float[] vector = new float[dims];
479+
List<float[]> vectorsList = new ArrayList<>();
480+
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
481+
System.arraycopy(mergeFloatVectorValues.vectorValue(iter.index()), 0, vector, 0, dims);
482+
vectorsList.add(vector);
483+
}
484+
float[][] vectors = vectorsList.toArray(new float[0][]);
485+
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(vectors);
486+
writeFieldInternal(fieldInfo, datasetOrVectors);
487+
return;
488+
}
489+
490+
470491
final int numVectors;
471492
String tempRawVectorsFileName = null;
472493
boolean success = false;
494+
// save merged vectors to a temporary file
473495
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "vec_", IOContext.DEFAULT)) {
474496
tempRawVectorsFileName = out.getName();
475-
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
497+
numVectors = writeFloatVectorValues(fieldInfo, out, mergeFloatVectorValues);
476498
CodecUtil.writeFooter(out);
477499
success = true;
478500
} finally {
479501
if (success == false && tempRawVectorsFileName != null) {
480502
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
481503
}
482504
}
483-
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
484-
// TODO: Improve this (not acceptable): pass tempRawVectorsFileName for the gpuIndex construction through MemorySegment
485-
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
486-
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
487-
float[] vector;
488-
for (int i = 0; i < numVectors; i++) {
489-
vector = floatVectorValues.vectorValue(i);
490-
System.arraycopy(vector, 0, vectors[i], 0, vector.length);
491-
}
492-
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(vectors);
505+
// Use MemorySegment to map the temp file and pass it as a dataset for building the GPU index
506+
try {
507+
final Path path = ((org.apache.lucene.store.FSDirectory) mergeState.segmentInfo.dir).getDirectory().resolve(tempRawVectorsFileName);
508+
Arena arena = Arena.ofShared();
509+
FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ);
510+
final MemorySegment memorySegment = fileChannel.map(
511+
FileChannel.MapMode.READ_ONLY,
512+
0,
513+
fileChannel.size() - CodecUtil.footerLength(),
514+
arena
515+
);
516+
Dataset dataset = new DatasetImpl(arena, memorySegment, numVectors, fieldInfo.getVectorDimension());
517+
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(dataset, null);
493518
writeFieldInternal(fieldInfo, datasetOrVectors);
494519
} finally {
495520
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
@@ -511,47 +536,6 @@ private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out,
511536
return numVectors;
512537
}
513538

514-
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
515-
if (numVectors == 0) {
516-
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
517-
}
518-
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
519-
final float[] vector = new float[fieldInfo.getVectorDimension()];
520-
return new FloatVectorValues() {
521-
@Override
522-
public float[] vectorValue(int ord) throws IOException {
523-
randomAccessInput.seek(ord * length + Integer.BYTES);
524-
randomAccessInput.readFloats(vector, 0, vector.length);
525-
return vector;
526-
}
527-
528-
@Override
529-
public FloatVectorValues copy() {
530-
return this;
531-
}
532-
533-
@Override
534-
public int dimension() {
535-
return fieldInfo.getVectorDimension();
536-
}
537-
538-
@Override
539-
public int size() {
540-
return numVectors;
541-
}
542-
543-
@Override
544-
public int ordToDoc(int ord) {
545-
try {
546-
randomAccessInput.seek(ord * length);
547-
return randomAccessInput.readInt();
548-
} catch (IOException e) {
549-
throw new UncheckedIOException(e);
550-
}
551-
}
552-
};
553-
}
554-
555539
private void writeMeta(
556540
FieldInfo field,
557541
long vectorIndexOffset,

0 commit comments

Comments
 (0)