Skip to content

Commit 723bb5a

Browse files
Revert "Attempt to use Java 22 DatasetImpl with MemorySegment"
This reverts commit ec3330e.
1 parent ec3330e commit 723bb5a

File tree

1 file changed

+55
-39
lines changed

1 file changed

+55
-39
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,8 @@
4242

4343
import java.io.IOException;
4444
import java.io.UncheckedIOException;
45-
import java.lang.foreign.Arena;
46-
import java.lang.foreign.MemorySegment;
4745
import java.nio.ByteBuffer;
4846
import java.nio.ByteOrder;
49-
import java.nio.channels.FileChannel;
50-
import java.nio.file.Path;
51-
import java.nio.file.StandardOpenOption;
5247
import java.util.ArrayList;
5348
import java.util.Arrays;
5449
import java.util.List;
@@ -465,56 +460,36 @@ public NodesIterator getNodesOnLevel(int level) {
465460
};
466461
}
467462

463+
// TODO: verify that this merge path behaves correctly when the merged segments contain deleted documents
468464
@Override
469465
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
470466
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
471-
int dims = fieldInfo.getVectorDimension();
472467
flatVectorWriter.mergeOneField(fieldInfo, mergeState);
473-
FloatVectorValues mergeFloatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
474-
475-
if (mergeFloatVectorValues.size() < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
476-
// TODO: check how deleted documents affect size value
477-
KnnVectorValues.DocIndexIterator iter = mergeFloatVectorValues.iterator();
478-
float[] vector = new float[dims];
479-
List<float[]> vectorsList = new ArrayList<>();
480-
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
481-
System.arraycopy(mergeFloatVectorValues.vectorValue(iter.index()), 0, vector, 0, dims);
482-
vectorsList.add(vector);
483-
}
484-
float[][] vectors = vectorsList.toArray(new float[0][]);
485-
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(vectors);
486-
writeFieldInternal(fieldInfo, datasetOrVectors);
487-
return;
488-
}
489-
490-
468+
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
469+
// save merged vector values to a temp file
491470
final int numVectors;
492471
String tempRawVectorsFileName = null;
493472
boolean success = false;
494-
// save merged vectors to a temporary file
495473
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "vec_", IOContext.DEFAULT)) {
496474
tempRawVectorsFileName = out.getName();
497-
numVectors = writeFloatVectorValues(fieldInfo, out, mergeFloatVectorValues);
475+
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
498476
CodecUtil.writeFooter(out);
499477
success = true;
500478
} finally {
501479
if (success == false && tempRawVectorsFileName != null) {
502480
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
503481
}
504482
}
505-
// Use MemorySegment to map the temp file and pass it as a dataset for building the GPU index
506-
try {
507-
final Path path = ((org.apache.lucene.store.FSDirectory) mergeState.segmentInfo.dir).getDirectory().resolve(tempRawVectorsFileName);
508-
Arena arena = Arena.ofShared();
509-
FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ);
510-
final MemorySegment memorySegment = fileChannel.map(
511-
FileChannel.MapMode.READ_ONLY,
512-
0,
513-
fileChannel.size() - CodecUtil.footerLength(),
514-
arena
515-
);
516-
Dataset dataset = new DatasetImpl(arena, memorySegment, numVectors, fieldInfo.getVectorDimension());
517-
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(dataset, null);
483+
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
484+
// TODO: copying every vector into an on-heap float[][] is not acceptable; instead pass the temp file (tempRawVectorsFileName) to the gpuIndex construction through a MemorySegment
485+
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
486+
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
487+
float[] vector;
488+
for (int i = 0; i < numVectors; i++) {
489+
vector = floatVectorValues.vectorValue(i);
490+
System.arraycopy(vector, 0, vectors[i], 0, vector.length);
491+
}
492+
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(vectors);
518493
writeFieldInternal(fieldInfo, datasetOrVectors);
519494
} finally {
520495
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
@@ -536,6 +511,47 @@ private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out,
536511
return numVectors;
537512
}
538513

514+
/**
 * Builds a random-access {@link FloatVectorValues} view over fixed-size records read from {@code randomAccessInput}.
 * <p>
 * Each record is read as one {@code int} followed by {@code fieldInfo.getVectorDimension()} floats: {@code ordToDoc}
 * returns the leading int and {@code vectorValue} returns the floats. Presumably this is the layout produced by
 * {@code writeFloatVectorValues} (TODO confirm against that method).
 * <p>
 * The returned view reuses a single {@code float[]} buffer for every {@code vectorValue} call, so callers must copy
 * the result before requesting another vector.
 *
 * @param fieldInfo         supplies the vector dimension
 * @param randomAccessInput input that is repositioned with {@code seek} on every read
 * @param numVectors        number of records available; reported by {@code size()}
 * @return an empty instance when {@code numVectors == 0}, otherwise the file-backed view
 */
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
515+
// Nothing to read: return an empty in-memory instance without touching the input.
if (numVectors == 0) {
516+
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
517+
}
518+
// Size in bytes of one record (one int + dimension floats); held as a long so that "ord * length" below is computed in 64-bit arithmetic.
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
519+
// Single scratch buffer shared by all vectorValue(...) calls on the returned instance.
final float[] vector = new float[fieldInfo.getVectorDimension()];
520+
return new FloatVectorValues() {
521+
@Override
522+
public float[] vectorValue(int ord) throws IOException {
523+
// Skip the leading int of record "ord" and read its floats into the shared buffer.
randomAccessInput.seek(ord * length + Integer.BYTES);
524+
randomAccessInput.readFloats(vector, 0, vector.length);
525+
return vector;
526+
}
527+
528+
@Override
529+
// NOTE(review): returns this instance rather than an independent copy, so "copies" share the scratch buffer and the
// underlying input position — confirm that no caller relies on copy() for independent or concurrent access.
public FloatVectorValues copy() {
530+
return this;
531+
}
532+
533+
@Override
534+
public int dimension() {
535+
return fieldInfo.getVectorDimension();
536+
}
537+
538+
@Override
539+
public int size() {
540+
return numVectors;
541+
}
542+
543+
@Override
544+
public int ordToDoc(int ord) {
545+
// This override declares no checked exceptions, so an IOException from the input is rethrown as UncheckedIOException.
try {
546+
// The int stored at the start of record "ord" is returned as the document id.
randomAccessInput.seek(ord * length);
547+
return randomAccessInput.readInt();
548+
} catch (IOException e) {
549+
throw new UncheckedIOException(e);
550+
}
551+
}
552+
};
553+
}
554+
539555
private void writeMeta(
540556
FieldInfo field,
541557
long vectorIndexOffset,

0 commit comments

Comments
 (0)