Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ public ESGpuHnswSQVectorsFormat(int maxConn, int beamWidth, Float confidenceInte

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new ESGpuHnswVectorsWriter(cuVSResourceManager, state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state));
return new ESGpuHnswVectorsWriter(
cuVSResourceManager,
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state),
flatVectorsFormat::fieldsReader
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public ESGpuHnswVectorsFormat() {

public ESGpuHnswVectorsFormat(int maxConn, int beamWidth) {
this(CuVSResourceManager.pooling(), maxConn, beamWidth);
};
}

public ESGpuHnswVectorsFormat(CuVSResourceManager cuVSResourceManager, int maxConn, int beamWidth) {
super(NAME);
Expand All @@ -66,7 +66,14 @@ public ESGpuHnswVectorsFormat(CuVSResourceManager cuVSResourceManager, int maxCo

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new ESGpuHnswVectorsWriter(cuVSResourceManager, state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state));
return new ESGpuHnswVectorsWriter(
cuVSResourceManager,
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state),
flatVectorsFormat::fieldsReader
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
Expand All @@ -37,6 +42,7 @@
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
Expand Down Expand Up @@ -78,6 +84,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
private final FlatVectorsWriter flatVectorWriter;

private final List<FieldWriter> fields = new ArrayList<>();
private final CheckedFunction<SegmentReadState, FlatVectorsReader, IOException> flatVectorsReaderProvider;
private boolean finished;
private final CuVSMatrix.DataType dataType;

Expand All @@ -86,8 +93,10 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
SegmentWriteState state,
int M,
int beamWidth,
FlatVectorsWriter flatVectorWriter
FlatVectorsWriter flatVectorWriter,
CheckedFunction<SegmentReadState, FlatVectorsReader, IOException> flatVectorsReaderProvider
) throws IOException {
this.flatVectorsReaderProvider = flatVectorsReaderProvider;
assert cuVSResourceManager != null : "CuVSResources must not be null";
this.cuVSResourceManager = cuVSResourceManager;
this.M = M;
Expand All @@ -96,6 +105,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter) {
dataType = CuVSMatrix.DataType.BYTE;
} else {
assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
dataType = CuVSMatrix.DataType.FLOAT;
}
this.segmentWriteState = state;
Expand Down Expand Up @@ -145,14 +155,84 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
return newField;
}

// Returns the MemorySegmentAccessInput backing the given vector values, or null when the
// values are not slice-backed or the unwrapped input does not support memory-segment access.
private static MemorySegmentAccessInput getMemorySegmentAccessInputOrNull(KnnVectorValues vectorValues) {
    if ((vectorValues instanceof HasIndexSlice withSlice) == false) {
        return null;
    }
    var unwrapped = FilterIndexInput.unwrapOnlyTest(withSlice.getSlice());
    return unwrapped instanceof MemorySegmentAccessInput segmentInput ? segmentInput : null;
}

@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
flatVectorWriter.flush(maxDoc, sortMap);
for (FieldWriter field : fields) {
if (sortMap == null) {
writeField(field);
} else {
writeSortingField(field, sortMap);

// TODO: this "mimics" a hypothetical/missing FlatVectorsWriter#getReader()
try (
FlatVectorsReader flatVectorsReader = flatVectorsReaderProvider.apply(
new SegmentReadState(
segmentWriteState.segmentInfo.dir,
segmentWriteState.segmentInfo,
new FieldInfos(fields.stream().map(x -> x.fieldInfo).toArray(FieldInfo[]::new)),
IOContext.DEFAULT,
segmentWriteState.segmentSuffix
)
)
) {
for (FieldWriter fieldWriter : fields) {
// This might be inefficient if getVectors() materializes a List<T>; however current implementations
// just return a reference to an inner, already allocated List<T>, so we are fine for now.
// TODO: change when/if Lucene introduces a direct FlatFieldVectorsWriter<T>#size()
var numVectors = fieldWriter.flatFieldVectorsWriter.getVectors().size();

final CuVSMatrix dataset;
if (numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
if (dataType == CuVSMatrix.DataType.FLOAT) {
FloatVectorValues floatVectorValues = flatVectorsReader.getFloatVectorValues(fieldWriter.fieldInfo.name);
var memorySegmentAccessInput = getMemorySegmentAccessInputOrNull(floatVectorValues);
if (memorySegmentAccessInput != null) {
dataset = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
} else {
var builder = CuVSMatrix.hostBuilder(numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
for (int i = 0; i < numVectors; ++i) {
builder.addVector(floatVectorValues.vectorValue(i));
}
dataset = builder.build();
}
} else {
assert dataType == CuVSMatrix.DataType.BYTE;
ByteVectorValues byteVectorValues = flatVectorsReader.getByteVectorValues(fieldWriter.fieldInfo.name);
var memorySegmentAccessInput = getMemorySegmentAccessInputOrNull(byteVectorValues);
if (memorySegmentAccessInput != null) {
dataset = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
} else {
var builder = CuVSMatrix.hostBuilder(numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
for (int i = 0; i < numVectors; ++i) {
builder.addVector(byteVectorValues.vectorValue(i));
}
dataset = builder.build();
}
}
} else {
dataset = null;
}

try {
if (sortMap == null) {
writeField(fieldWriter.fieldInfo, dataset, numVectors);
} else {
writeSortingField(fieldWriter.fieldInfo, dataset, numVectors, sortMap);
}
} finally {
if (dataset != null) {
dataset.close();
}
}
}
}
}
Expand Down Expand Up @@ -185,83 +265,35 @@ public long ramBytesUsed() {
return total;
}

/**
 * Holder carrying the vector data in exactly one of two forms: a {@code CuVSMatrix} dataset
 * (when there are enough vectors for a GPU build) or the raw {@code float[][]} vectors
 * (below the GPU threshold).
 */
private static final class DatasetOrVectors {
    private final CuVSMatrix dataset;
    private final float[][] vectors;

    // Wraps a raw array: below MIN_NUM_VECTORS_FOR_GPU_BUILD keep the array as-is,
    // otherwise convert it into a CuVSMatrix suitable for the GPU index build.
    static DatasetOrVectors fromArray(float[][] vectors) {
        if (vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
            return new DatasetOrVectors(null, vectors);
        }
        return new DatasetOrVectors(CuVSMatrix.ofArray(vectors), null);
    }

    static DatasetOrVectors fromDataset(CuVSMatrix dataset) {
        return new DatasetOrVectors(dataset, null);
    }

    private DatasetOrVectors(CuVSMatrix dataset, float[][] vectors) {
        this.dataset = dataset;
        this.vectors = vectors;
        validateState();
    }

    // Invariant check: exactly one of the two representations must be present.
    private void validateState() {
        boolean hasDataset = dataset != null;
        boolean hasVectors = vectors != null;
        if (hasDataset == hasVectors) {
            throw new IllegalStateException("Exactly one of dataset or vectors must be non-null");
        }
    }

    int size() {
        return dataset == null ? vectors.length : (int) dataset.size();
    }

    CuVSMatrix getDataset() {
        return dataset;
    }

    float[][] getVectors() {
        return vectors;
    }
}

// Materializes the buffered vectors for this field and writes the field's index data.
private void writeField(FieldWriter fieldWriter) throws IOException {
    var bufferedVectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
    writeFieldInternal(fieldWriter.fieldInfo, DatasetOrVectors.fromArray(bufferedVectors));
}

private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
private void writeSortingField(FieldInfo fieldInfo, CuVSMatrix datasetOrVectors, int size, Sorter.DocMap sortMap) throws IOException {
// The flatFieldVectorsWriter's flush method, called before this, has already sorted the vectors according to the sortMap.
// We can now treat them as a simple, sorted list of vectors.
float[][] vectors = fieldData.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
writeFieldInternal(fieldData.fieldInfo, DatasetOrVectors.fromArray(vectors));
writeField(fieldInfo, datasetOrVectors, size);
}

private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
private void writeField(FieldInfo fieldInfo, CuVSMatrix dataset, int size) throws IOException {
try {
long vectorIndexOffset = vectorIndex.getFilePointer();
int[][] graphLevelNodeOffsets = new int[1][];
HnswGraph mockGraph;
if (datasetOrVectors.getVectors() != null) {
int size = datasetOrVectors.size();
final HnswGraph graph;
if (dataset == null) {
if (logger.isDebugEnabled()) {
logger.debug("Skip building carga index; vectors length {} < {} (min for GPU)", size, MIN_NUM_VECTORS_FOR_GPU_BUILD);
}
mockGraph = writeGraph(size, graphLevelNodeOffsets);
graph = writeMockGraph(size, graphLevelNodeOffsets);
} else {
var dataset = datasetOrVectors.getDataset();
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns(), dataset.dataType());
try {
try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
assert index != null : "GPU index should be built for field: " + fieldInfo.name;
mockGraph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
graph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
}
} finally {
cuVSResourceManager.release(cuVSResources);
}
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, datasetOrVectors.size(), mockGraph, graphLevelNodeOffsets);
writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, size, graph, graphLevelNodeOffsets);
} catch (IOException e) {
throw e;
} catch (Throwable t) {
Expand Down Expand Up @@ -337,7 +369,7 @@ private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) th
}

// create a mock graph where every node is connected to every other node
private HnswGraph writeGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
private HnswGraph writeMockGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
if (elementCount == 0) {
return null;
}
Expand Down Expand Up @@ -435,20 +467,43 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
}
}
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
DatasetOrVectors datasetOrVectors;
var input = FilterIndexInput.unwrapOnlyTest(in);
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
var ds = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
datasetOrVectors = DatasetOrVectors.fromDataset(ds);

final CuVSMatrix dataset;
if (numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
// Direct access to mmapped file
dataset = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
} else {
var builder = CuVSMatrix.hostBuilder(numVectors, fieldInfo.getVectorDimension(), dataType);
// Read vector-by-vector
if (dataType == CuVSMatrix.DataType.FLOAT) {
float[] vector = new float[fieldInfo.getVectorDimension()];
for (int i = 0; i < numVectors; ++i) {
input.readFloats(vector, 0, fieldInfo.getVectorDimension());
}
} else {
assert dataType == CuVSMatrix.DataType.BYTE;
byte[] vector = new byte[fieldInfo.getVectorDimension()];
for (int i = 0; i < numVectors; ++i) {
input.readBytes(vector, 0, fieldInfo.getVectorDimension());
}
}
dataset = builder.build();
}
} else {
// assert numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD : "numVectors: " + numVectors;
// we don't really need real value for vectors here,
// we just build a mock graph where every node is connected to every other node
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
datasetOrVectors = DatasetOrVectors.fromArray(vectors);
dataset = null;
}
try {
writeField(fieldInfo, dataset, numVectors);
} finally {
if (dataset != null) {
dataset.close();
}
}
writeFieldInternal(fieldInfo, datasetOrVectors);
} finally {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.gpu.reflect;

import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.store.IndexOutput;
import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

/**
 * Reflection helpers for reaching into Lucene's flat vector writers to obtain the
 * {@link IndexOutput} their vector data is being written to.
 *
 * <p>Relies on {@link VarHandle}s over private fields of Lucene/ES writer classes; the
 * owning modules must be opened to this package for {@code privateLookupIn} to succeed
 * (the static initializer throws {@link AssertionError} otherwise).
 */
public class VectorsFormatReflectionUtils {

    // Handle for Lucene99FlatVectorsWriter.vectorData (raw flat vector output).
    private static final VarHandle FLAT_VECTOR_DATA_HANDLE;
    // Handle for Lucene99ScalarQuantizedVectorsWriter.quantizedVectorData.
    private static final VarHandle QUANTIZED_VECTOR_DATA_HANDLE;
    // Handle for ES814ScalarQuantizedVectorsWriter.delegate (a Lucene99ScalarQuantizedVectorsWriter).
    private static final VarHandle DELEGATE_WRITER_HANDLE;

    static final Class<?> L99_SQ_VW_CLS = Lucene99ScalarQuantizedVectorsWriter.class;
    static final Class<?> L99_F_VW_CLS = Lucene99FlatVectorsWriter.class;
    static final Class<?> ES814_SQ_VW_CLS = ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter.class;

    static {
        try {
            var lookup = MethodHandles.privateLookupIn(L99_F_VW_CLS, MethodHandles.lookup());
            FLAT_VECTOR_DATA_HANDLE = lookup.findVarHandle(L99_F_VW_CLS, "vectorData", IndexOutput.class);

            lookup = MethodHandles.privateLookupIn(L99_SQ_VW_CLS, MethodHandles.lookup());
            QUANTIZED_VECTOR_DATA_HANDLE = lookup.findVarHandle(L99_SQ_VW_CLS, "quantizedVectorData", IndexOutput.class);

            lookup = MethodHandles.privateLookupIn(ES814_SQ_VW_CLS, MethodHandles.lookup());
            DELEGATE_WRITER_HANDLE = lookup.findVarHandle(ES814_SQ_VW_CLS, "delegate", L99_SQ_VW_CLS);

        } catch (IllegalAccessException e) {
            throw new AssertionError("should not happen, check opens", e);
        } catch (ReflectiveOperationException e) {
            throw new AssertionError(e);
        }
    }

    // Utility class: no instances.
    private VectorsFormatReflectionUtils() {}

    /**
     * Returns the quantized vector data {@link IndexOutput} of the given ES814
     * scalar-quantized writer (via its Lucene99 delegate).
     *
     * <p>NOTE(review): despite the name, this reads {@code quantizedVectorData}, while
     * {@link #getQuantizedVectorDataIndexOutput} reads the raw {@code vectorData} — the two
     * method names look swapped relative to the fields they access. Behavior is kept as-is
     * because callers may depend on it; confirm with call sites before renaming.
     */
    public static IndexOutput getVectorDataIndexOutput(FlatVectorsWriter flatVectorWriter) {
        assert flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter;
        var delegate = (Lucene99ScalarQuantizedVectorsWriter) DELEGATE_WRITER_HANDLE.get(flatVectorWriter);
        return (IndexOutput) QUANTIZED_VECTOR_DATA_HANDLE.get(delegate);
    }

    /**
     * Returns the raw (flat) vector data {@link IndexOutput} of the given
     * {@code Lucene99FlatVectorsWriter}.
     *
     * <p>NOTE(review): see {@link #getVectorDataIndexOutput} — the names of these two
     * accessors appear swapped relative to the fields they read.
     */
    public static IndexOutput getQuantizedVectorDataIndexOutput(FlatVectorsWriter flatVectorWriter) {
        assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
        return (IndexOutput) FLAT_VECTOR_DATA_HANDLE.get(flatVectorWriter);
    }
}