Skip to content

Commit 4f9b98a

Browse files
committed
Use memory-mapped MemorySegment from the vectorData IndexOutput/Input
1 parent 5a6ae3b commit 4f9b98a

File tree

2 files changed

+99
-33
lines changed

2 files changed

+99
-33
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.lucene.codecs.KnnVectorsWriter;
1717
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
1818
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
19+
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
1920
import org.apache.lucene.index.ByteVectorValues;
2021
import org.apache.lucene.index.DocsWithFieldSet;
2122
import org.apache.lucene.index.FieldInfo;
@@ -42,6 +43,7 @@
4243
import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
4344
import org.elasticsearch.logging.LogManager;
4445
import org.elasticsearch.logging.Logger;
46+
import org.elasticsearch.xpack.gpu.reflect.VectorsFormatReflectionUtils;
4547

4648
import java.io.IOException;
4749
import java.nio.ByteBuffer;
@@ -73,6 +75,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
7375
private final CuVSResourceManager cuVSResourceManager;
7476
private final SegmentWriteState segmentWriteState;
7577
private final IndexOutput meta, vectorIndex;
78+
private final IndexOutput vectorData;
7679
private final int M;
7780
private final int beamWidth;
7881
private final FlatVectorsWriter flatVectorWriter;
@@ -94,8 +97,11 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
9497
this.beamWidth = beamWidth;
9598
this.flatVectorWriter = flatVectorWriter;
9699
if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter) {
100+
vectorData = VectorsFormatReflectionUtils.getQuantizedVectorDataIndexOutput(flatVectorWriter);
97101
dataType = CuVSMatrix.DataType.BYTE;
98102
} else {
103+
assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
104+
vectorData = VectorsFormatReflectionUtils.getVectorDataIndexOutput(flatVectorWriter);
99105
dataType = CuVSMatrix.DataType.FLOAT;
100106
}
101107
this.segmentWriteState = state;
@@ -148,11 +154,38 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
148154
@Override
149155
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
150156
flatVectorWriter.flush(maxDoc, sortMap);
151-
for (FieldWriter field : fields) {
152-
if (sortMap == null) {
153-
writeField(field);
154-
} else {
155-
writeSortingField(field, sortMap);
157+
158+
try (IndexInput in = segmentWriteState.segmentInfo.dir.openInput(vectorData.getName(), IOContext.DEFAULT)) {
159+
var input = FilterIndexInput.unwrapOnlyTest(in);
160+
161+
for (FieldWriter fieldWriter : fields) {
162+
// TODO: is this inefficient? Can we get "size" in another way?
163+
var numVectors = fieldWriter.flatFieldVectorsWriter.getVectors().size();
164+
165+
final DatasetOrVectors datasetOrVectors;
166+
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
167+
// TODO: we are iterating over multiple fields, we probably need to memorySegmentAccessInput.segmentSliceOrNull()?
168+
var ds = DatasetUtils.getInstance()
169+
.fromInput(memorySegmentAccessInput, numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
170+
datasetOrVectors = DatasetOrVectors.fromDataset(ds);
171+
} else {
172+
var builder = CuVSMatrix.hostBuilder(numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
173+
for (var vector : fieldWriter.flatFieldVectorsWriter.getVectors()) {
174+
builder.addVector(vector);
175+
}
176+
177+
datasetOrVectors = DatasetOrVectors.fromDataset(builder.build());
178+
}
179+
180+
try {
181+
if (sortMap == null) {
182+
writeField(fieldWriter.fieldInfo, datasetOrVectors);
183+
} else {
184+
writeSortingField(fieldWriter.fieldInfo, datasetOrVectors, sortMap);
185+
}
186+
} finally {
187+
datasetOrVectors.close();
188+
}
156189
}
157190
}
158191
}
@@ -221,38 +254,13 @@ public void close() {
221254
}
222255
}
223256

224-
private void writeField(FieldWriter fieldWriter) throws IOException {
225-
var vectors = fieldWriter.flatFieldVectorsWriter.getVectors();
226-
final DatasetOrVectors datasetOrVectors;
227-
if (vectors.size() < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
228-
// Use vectors/CPU
229-
datasetOrVectors = DatasetOrVectors.fromArray(vectors);
230-
} else {
231-
// Avoid another heap copy (the float[][])
232-
233-
// TODO: another alternative is to use CuVSMatrix.deviceBuilder(), but this requires more effort
234-
// 1. support no-copy CuVSDeviceMatrix as input in CagraIndex
235-
// 2. ensure we are already holding a CuVSResource here
236-
var builder = CuVSMatrix.hostBuilder(vectors.size(), vectors.getFirst().length, dataType);
237-
for (var vector : vectors) {
238-
builder.addVector(vector);
239-
}
240-
datasetOrVectors = DatasetOrVectors.fromDataset(builder.build());
241-
}
242-
try {
243-
writeFieldInternal(fieldWriter.fieldInfo, datasetOrVectors);
244-
} finally {
245-
datasetOrVectors.close();
246-
}
247-
}
248-
249-
private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
257+
private void writeSortingField(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors, Sorter.DocMap sortMap) throws IOException {
250258
// The flatFieldVectorsWriter's flush method, called before this, has already sorted the vectors according to the sortMap.
251259
// We can now treat them as a simple, sorted list of vectors.
252-
writeField(fieldData);
260+
writeField(fieldInfo, datasetOrVectors);
253261
}
254262

255-
private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
263+
private void writeField(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
256264
try {
257265
long vectorIndexOffset = vectorIndex.getFilePointer();
258266
int[][] graphLevelNodeOffsets = new int[1][];
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.reflect;
9+
10+
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
11+
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
12+
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
13+
import org.apache.lucene.store.IndexOutput;
14+
import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
15+
16+
import java.lang.invoke.MethodHandles;
17+
import java.lang.invoke.VarHandle;
18+
19+
public class VectorsFormatReflectionUtils {
20+
21+
private static final VarHandle FLAT_VECTOR_DATA_HANDLE;
22+
private static final VarHandle QUANTIZED_VECTOR_DATA_HANDLE;
23+
private static final VarHandle DELEGATE_WRITER_HANDLE;
24+
25+
static final Class<?> L99_SQ_VW_CLS = Lucene99ScalarQuantizedVectorsWriter.class;
26+
static final Class<?> L99_F_VW_CLS = Lucene99FlatVectorsWriter.class;
27+
static final Class<?> ES814_SQ_VW_CLS = ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter.class;
28+
29+
static {
30+
try {
31+
var lookup = MethodHandles.privateLookupIn(L99_F_VW_CLS, MethodHandles.lookup());
32+
FLAT_VECTOR_DATA_HANDLE = lookup.findVarHandle(L99_F_VW_CLS, "vectorData", IndexOutput.class);
33+
34+
lookup = MethodHandles.privateLookupIn(L99_SQ_VW_CLS, MethodHandles.lookup());
35+
QUANTIZED_VECTOR_DATA_HANDLE = lookup.findVarHandle(L99_SQ_VW_CLS, "quantizedVectorData", IndexOutput.class);
36+
37+
lookup = MethodHandles.privateLookupIn(ES814_SQ_VW_CLS, MethodHandles.lookup());
38+
DELEGATE_WRITER_HANDLE = lookup.findVarHandle(ES814_SQ_VW_CLS, "delegate", L99_SQ_VW_CLS);
39+
40+
} catch (IllegalAccessException e) {
41+
throw new AssertionError("should not happen, check opens", e);
42+
}
43+
catch (ReflectiveOperationException e) {
44+
throw new AssertionError(e);
45+
}
46+
}
47+
48+
public static IndexOutput getVectorDataIndexOutput(FlatVectorsWriter flatVectorWriter) {
49+
assert flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter;
50+
var delegate = (Lucene99ScalarQuantizedVectorsWriter)DELEGATE_WRITER_HANDLE.get(flatVectorWriter);
51+
return (IndexOutput) QUANTIZED_VECTOR_DATA_HANDLE.get(delegate);
52+
}
53+
54+
public static IndexOutput getQuantizedVectorDataIndexOutput(FlatVectorsWriter flatVectorWriter) {
55+
assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
56+
return (IndexOutput) FLAT_VECTOR_DATA_HANDLE.get(flatVectorWriter);
57+
}
58+
}

0 commit comments

Comments
 (0)