Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,83 +185,69 @@ public long ramBytesUsed() {
return total;
}

private static final class DatasetOrVectors {
private final CuVSMatrix dataset;
private final float[][] vectors;

static DatasetOrVectors fromArray(float[][] vectors) {
return new DatasetOrVectors(
vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : CuVSMatrix.ofArray(vectors),
vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? vectors : null
);
}

static DatasetOrVectors fromDataset(CuVSMatrix dataset) {
return new DatasetOrVectors(dataset, null);
}

private DatasetOrVectors(CuVSMatrix dataset, float[][] vectors) {
this.dataset = dataset;
this.vectors = vectors;
validateState();
}
private void writeField(FieldWriter fieldWriter) throws IOException {

private void validateState() {
if ((dataset == null && vectors == null) || (dataset != null && vectors != null)) {
throw new IllegalStateException("Exactly one of dataset or vectors must be non-null");
// if (dataType == CuVSMatrix.DataType.FLOAT) {
float[][] vectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
final CuVSMatrix dataset = vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : CuVSMatrix.ofArray(vectors);
try {
writeFieldInternal(fieldWriter.fieldInfo, dataset, vectors.length);
} finally {
if (dataset != null) {
dataset.close();
}
}

int size() {
return dataset != null ? (int) dataset.size() : vectors.length;
}

CuVSMatrix getDataset() {
return dataset;
}

float[][] getVectors() {
return vectors;
}
// } else {
// throw new UnsupportedOperationException("BYTE is still unsupported");
// }
}

/**
 * Writes a single field: copies the vectors held by the field's flat vectors writer into an
 * array, wraps that array with {@code DatasetOrVectors.fromArray} and passes the result,
 * together with the field's {@code FieldInfo}, to {@code writeFieldInternal}.
 *
 * @param fieldWriter the per-field writer whose vectors are written
 * @throws IOException declared by this method and propagated from the call it delegates to
 */
private void writeField(FieldWriter fieldWriter) throws IOException {
    final float[][] fieldVectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
    final DatasetOrVectors wrapped = DatasetOrVectors.fromArray(fieldVectors);
    writeFieldInternal(fieldWriter.fieldInfo, wrapped);
}

private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
private void writeSortingField(FieldWriter fieldWriter, Sorter.DocMap sortMap) throws IOException {
// The flatFieldVectorsWriter's flush method, called before this, has already sorted the vectors according to the sortMap.
// We can now treat them as a simple, sorted list of vectors.
float[][] vectors = fieldData.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
writeFieldInternal(fieldData.fieldInfo, DatasetOrVectors.fromArray(vectors));

// if (dataType == CuVSMatrix.DataType.FLOAT) {
float[][] vectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
final CuVSMatrix dataset = vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : CuVSMatrix.ofArray(vectors);
try {
writeFieldInternal(fieldWriter.fieldInfo, dataset, vectors.length);
} finally {
if (dataset != null) {
dataset.close();
}
}
// } else {
// throw new UnsupportedOperationException("BYTE is still unsupported");
// }
}

private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
private void writeFieldInternal(FieldInfo fieldInfo, CuVSMatrix dataset, int datasetSize) throws IOException {
try {
long vectorIndexOffset = vectorIndex.getFilePointer();
int[][] graphLevelNodeOffsets = new int[1][];
HnswGraph mockGraph;
if (datasetOrVectors.getVectors() != null) {
int size = datasetOrVectors.size();
final HnswGraph graph;
if (dataset == null) {
if (logger.isDebugEnabled()) {
logger.debug("Skip building carga index; vectors length {} < {} (min for GPU)", size, MIN_NUM_VECTORS_FOR_GPU_BUILD);
logger.debug(
"Skip building carga index; vectors length {} < {} (min for GPU)",
datasetSize,
MIN_NUM_VECTORS_FOR_GPU_BUILD
);
}
mockGraph = writeGraph(size, graphLevelNodeOffsets);
graph = writeMockGraph(datasetSize, graphLevelNodeOffsets);
} else {
var dataset = datasetOrVectors.getDataset();
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns(), dataset.dataType());
try {
try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
assert index != null : "GPU index should be built for field: " + fieldInfo.name;
mockGraph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
graph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
}
} finally {
cuVSResourceManager.release(cuVSResources);
}
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, datasetOrVectors.size(), mockGraph, graphLevelNodeOffsets);
writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, datasetSize, graph, graphLevelNodeOffsets);
} catch (IOException e) {
throw e;
} catch (Throwable t) {
Expand Down Expand Up @@ -337,7 +323,7 @@ private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) th
}

// create a mock graph where every node is connected to every other node
private HnswGraph writeGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
private HnswGraph writeMockGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
if (elementCount == 0) {
return null;
}
Expand Down Expand Up @@ -435,20 +421,43 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
}
}
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
DatasetOrVectors datasetOrVectors;
var input = FilterIndexInput.unwrapOnlyTest(in);
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
var ds = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
datasetOrVectors = DatasetOrVectors.fromDataset(ds);

final CuVSMatrix dataset;
if (numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
// Direct access to mmapped file
dataset = DatasetUtils.getInstance()
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
} else {
// The input is not a MemorySegmentAccessInput here, so the vectors are copied one by one
// from the input into a host matrix builder.
var builder = CuVSMatrix.hostBuilder(numVectors, fieldInfo.getVectorDimension(), dataType);
// Read vector-by-vector. The scratch array is reused across iterations, so every vector has to be
// handed to the builder right after it is read; without this the built matrix never receives the data.
// NOTE(review): assumes addVector copies the array contents rather than retaining the reference - confirm.
if (dataType == CuVSMatrix.DataType.FLOAT) {
    float[] vector = new float[fieldInfo.getVectorDimension()];
    for (int i = 0; i < numVectors; ++i) {
        input.readFloats(vector, 0, fieldInfo.getVectorDimension());
        builder.addVector(vector);
    }
} else {
    assert dataType == CuVSMatrix.DataType.BYTE;
    byte[] vector = new byte[fieldInfo.getVectorDimension()];
    for (int i = 0; i < numVectors; ++i) {
        input.readBytes(vector, 0, fieldInfo.getVectorDimension());
        builder.addVector(vector);
    }
}
dataset = builder.build();
}
} else {
// assert numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD : "numVectors: " + numVectors;
// we don't really need real value for vectors here,
// we just build a mock graph where every node is connected to every other node
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
datasetOrVectors = DatasetOrVectors.fromArray(vectors);
dataset = null;
}
try {
writeFieldInternal(fieldInfo, dataset, numVectors);
} finally {
if (dataset != null) {
dataset.close();
}
}
writeFieldInternal(fieldInfo, datasetOrVectors);
} finally {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.xpack.gpu.GPUSupport;
import org.junit.BeforeClass;

@LuceneTestCase.SuppressSysoutChecks(bugUrl = "https://github.com/rapidsai/cuvs/issues/1310")
public class ESGpuHnswSQVectorsFormatTests extends BaseKnnVectorsFormatTestCase {

static {
Expand Down