GPU HNSW vectors writer: build the cuVS dataset from the flat vectors files instead of on-heap arrays.

- Both GPU formats pass flatVectorsFormat::fieldsReader to ESGpuHnswVectorsWriter.
- flush() opens a FlatVectorsReader over the just-flushed segment and uses the mmapped slice when
  one is available; otherwise vectors are copied one by one into a host CuVSMatrix builder.
- mergeOneField() gets the same host-builder fallback; every vector read from the temp file is
  added to the builder before build() is called.
- DatasetOrVectors is removed; a null dataset now means "write the mock graph".
- VectorsFormatReflectionUtils exposes the private IndexOutput fields of the flat vectors writers.

diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswSQVectorsFormat.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswSQVectorsFormat.java
index 400a855db6d6b..a6e20e081a1dc 100644
--- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswSQVectorsFormat.java
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswSQVectorsFormat.java
@@ -61,7 +61,14 @@ public ESGpuHnswSQVectorsFormat(int maxConn, int beamWidth, Float confidenceInte
 
     @Override
     public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
-        return new ESGpuHnswVectorsWriter(cuVSResourceManager, state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state));
+        return new ESGpuHnswVectorsWriter(
+            cuVSResourceManager,
+            state,
+            maxConn,
+            beamWidth,
+            flatVectorsFormat.fieldsWriter(state),
+            flatVectorsFormat::fieldsReader
+        );
     }
 
     @Override
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsFormat.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsFormat.java
index b06b452435c83..c5d58c4a4f107 100644
--- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsFormat.java
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsFormat.java
@@ -55,7 +55,7 @@ public ESGpuHnswVectorsFormat() {
 
     public ESGpuHnswVectorsFormat(int maxConn, int beamWidth) {
         this(CuVSResourceManager.pooling(), maxConn, beamWidth);
-    };
+    }
 
     public ESGpuHnswVectorsFormat(CuVSResourceManager cuVSResourceManager, int maxConn, int beamWidth) {
         super(NAME);
@@ -66,7 +66,14 @@ public ESGpuHnswVectorsFormat(CuVSResourceManager cuVSResourceManager, int maxCo
 
     @Override
     public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
-        return new ESGpuHnswVectorsWriter(cuVSResourceManager, state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state));
+        return new ESGpuHnswVectorsWriter(
+            cuVSResourceManager,
+            state,
+            maxConn,
+            beamWidth,
+            flatVectorsFormat.fieldsWriter(state),
+            flatVectorsFormat::fieldsReader
+        );
     }
 
     @Override
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java
index df9b47ee5c62d..ab33739e7d1ca 100644
--- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java
@@ -15,14 +15,19 @@
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsWriter;
 import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
 import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene95.HasIndexSlice;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.DocsWithFieldSet;
 import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.KnnVectorValues;
 import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.Sorter;
 import org.apache.lucene.index.VectorEncoding;
@@ -37,6 +42,7 @@
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
 import org.apache.lucene.util.packed.DirectMonotonicWriter;
 import org.apache.lucene.util.quantization.ScalarQuantizer;
+import org.elasticsearch.core.CheckedFunction;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
@@ -78,6 +84,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
     private final FlatVectorsWriter flatVectorWriter;
 
     private final List<FieldWriter> fields = new ArrayList<>();
+    private final CheckedFunction<SegmentReadState, FlatVectorsReader, IOException> flatVectorsReaderProvider;
     private boolean finished;
     private final CuVSMatrix.DataType dataType;
 
@@ -86,8 +93,10 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
         SegmentWriteState state,
         int M,
         int beamWidth,
-        FlatVectorsWriter flatVectorWriter
+        FlatVectorsWriter flatVectorWriter,
+        CheckedFunction<SegmentReadState, FlatVectorsReader, IOException> flatVectorsReaderProvider
     ) throws IOException {
+        this.flatVectorsReaderProvider = flatVectorsReaderProvider;
         assert cuVSResourceManager != null : "CuVSResources must not be null";
         this.cuVSResourceManager = cuVSResourceManager;
         this.M = M;
@@ -96,6 +105,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
         if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter) {
             dataType = CuVSMatrix.DataType.BYTE;
         } else {
+            assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
             dataType = CuVSMatrix.DataType.FLOAT;
         }
         this.segmentWriteState = state;
@@ -145,14 +155,84 @@ public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException
         return newField;
     }
 
+    private static MemorySegmentAccessInput getMemorySegmentAccessInputOrNull(KnnVectorValues vectorValues) {
+
+        if (vectorValues instanceof HasIndexSlice indexSlice) {
+            var input = FilterIndexInput.unwrapOnlyTest(indexSlice.getSlice());
+            if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
+                return memorySegmentAccessInput;
+            }
+        }
+        return null;
+    }
+
     @Override
     public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
         flatVectorWriter.flush(maxDoc, sortMap);
-        for (FieldWriter field : fields) {
-            if (sortMap == null) {
-                writeField(field);
-            } else {
-                writeSortingField(field, sortMap);
+
+        // TODO: this "mimics" a hypothetical/missing FlatVectorsWriter#getReader()
+        try (
+            FlatVectorsReader flatVectorsReader = flatVectorsReaderProvider.apply(
+                new SegmentReadState(
+                    segmentWriteState.segmentInfo.dir,
+                    segmentWriteState.segmentInfo,
+                    new FieldInfos(fields.stream().map(x -> x.fieldInfo).toArray(FieldInfo[]::new)),
+                    IOContext.DEFAULT,
+                    segmentWriteState.segmentSuffix
+                )
+            )
+        ) {
+            for (FieldWriter fieldWriter : fields) {
+                // This might be inefficient if getVectors() materializes a List; however current implementations
+                // just return a reference to an inner, already allocated List, so we are fine for now.
+                // TODO: change when/if Lucene introduces a direct FlatFieldVectorsWriter#size()
+                var numVectors = fieldWriter.flatFieldVectorsWriter.getVectors().size();
+
+                final CuVSMatrix dataset;
+                if (numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
+                    if (dataType == CuVSMatrix.DataType.FLOAT) {
+                        FloatVectorValues floatVectorValues = flatVectorsReader.getFloatVectorValues(fieldWriter.fieldInfo.name);
+                        var memorySegmentAccessInput = getMemorySegmentAccessInputOrNull(floatVectorValues);
+                        if (memorySegmentAccessInput != null) {
+                            dataset = DatasetUtils.getInstance()
+                                .fromInput(memorySegmentAccessInput, numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
+                        } else {
+                            var builder = CuVSMatrix.hostBuilder(numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
+                            for (int i = 0; i < numVectors; ++i) {
+                                builder.addVector(floatVectorValues.vectorValue(i));
+                            }
+                            dataset = builder.build();
+                        }
+                    } else {
+                        assert dataType == CuVSMatrix.DataType.BYTE;
+                        ByteVectorValues byteVectorValues = flatVectorsReader.getByteVectorValues(fieldWriter.fieldInfo.name);
+                        var memorySegmentAccessInput = getMemorySegmentAccessInputOrNull(byteVectorValues);
+                        if (memorySegmentAccessInput != null) {
+                            dataset = DatasetUtils.getInstance()
+                                .fromInput(memorySegmentAccessInput, numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
+                        } else {
+                            var builder = CuVSMatrix.hostBuilder(numVectors, fieldWriter.fieldInfo.getVectorDimension(), dataType);
+                            for (int i = 0; i < numVectors; ++i) {
+                                builder.addVector(byteVectorValues.vectorValue(i));
+                            }
+                            dataset = builder.build();
+                        }
+                    }
+                } else {
+                    dataset = null;
+                }
+
+                try {
+                    if (sortMap == null) {
+                        writeField(fieldWriter.fieldInfo, dataset, numVectors);
+                    } else {
+                        writeSortingField(fieldWriter.fieldInfo, dataset, numVectors, sortMap);
+                    }
+                } finally {
+                    if (dataset != null) {
+                        dataset.close();
+                    }
+                }
             }
         }
     }
@@ -185,83 +265,35 @@ public long ramBytesUsed() {
         return total;
     }
 
-    private static final class DatasetOrVectors {
-        private final CuVSMatrix dataset;
-        private final float[][] vectors;
-
-        static DatasetOrVectors fromArray(float[][] vectors) {
-            return new DatasetOrVectors(
-                vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : CuVSMatrix.ofArray(vectors),
-                vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? vectors : null
-            );
-        }
-
-        static DatasetOrVectors fromDataset(CuVSMatrix dataset) {
-            return new DatasetOrVectors(dataset, null);
-        }
-
-        private DatasetOrVectors(CuVSMatrix dataset, float[][] vectors) {
-            this.dataset = dataset;
-            this.vectors = vectors;
-            validateState();
-        }
-
-        private void validateState() {
-            if ((dataset == null && vectors == null) || (dataset != null && vectors != null)) {
-                throw new IllegalStateException("Exactly one of dataset or vectors must be non-null");
-            }
-        }
-
-        int size() {
-            return dataset != null ? (int) dataset.size() : vectors.length;
-        }
-
-        CuVSMatrix getDataset() {
-            return dataset;
-        }
-
-        float[][] getVectors() {
-            return vectors;
-        }
-    }
-
-    private void writeField(FieldWriter fieldWriter) throws IOException {
-        float[][] vectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
-        writeFieldInternal(fieldWriter.fieldInfo, DatasetOrVectors.fromArray(vectors));
-    }
-
-    private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
+    private void writeSortingField(FieldInfo fieldInfo, CuVSMatrix datasetOrVectors, int size, Sorter.DocMap sortMap) throws IOException {
         // The flatFieldVectorsWriter's flush method, called before this, has already sorted the vectors according to the sortMap.
         // We can now treat them as a simple, sorted list of vectors.
-        float[][] vectors = fieldData.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
-        writeFieldInternal(fieldData.fieldInfo, DatasetOrVectors.fromArray(vectors));
+        writeField(fieldInfo, datasetOrVectors, size);
     }
 
-    private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrVectors) throws IOException {
+    private void writeField(FieldInfo fieldInfo, CuVSMatrix dataset, int size) throws IOException {
         try {
             long vectorIndexOffset = vectorIndex.getFilePointer();
             int[][] graphLevelNodeOffsets = new int[1][];
-            HnswGraph mockGraph;
-            if (datasetOrVectors.getVectors() != null) {
-                int size = datasetOrVectors.size();
+            final HnswGraph graph;
+            if (dataset == null) {
                 if (logger.isDebugEnabled()) {
                     logger.debug("Skip building carga index; vectors length {} < {} (min for GPU)", size, MIN_NUM_VECTORS_FOR_GPU_BUILD);
                 }
-                mockGraph = writeGraph(size, graphLevelNodeOffsets);
+                graph = writeMockGraph(size, graphLevelNodeOffsets);
             } else {
-                var dataset = datasetOrVectors.getDataset();
                 var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns(), dataset.dataType());
                 try {
                     try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
                         assert index != null : "GPU index should be built for field: " + fieldInfo.name;
-                        mockGraph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
+                        graph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
                     }
                 } finally {
                     cuVSResourceManager.release(cuVSResources);
                 }
             }
             long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
-            writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, datasetOrVectors.size(), mockGraph, graphLevelNodeOffsets);
+            writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, size, graph, graphLevelNodeOffsets);
         } catch (IOException e) {
             throw e;
         } catch (Throwable t) {
@@ -337,7 +369,7 @@ private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) th
     }
 
     // create a mock graph where every node is connected to every other node
-    private HnswGraph writeGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
+    private HnswGraph writeMockGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
         if (elementCount == 0) {
             return null;
         }
@@ -435,20 +467,45 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
             }
         }
         try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
-            DatasetOrVectors datasetOrVectors;
             var input = FilterIndexInput.unwrapOnlyTest(in);
-            if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
-                var ds = DatasetUtils.getInstance()
-                    .fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
-                datasetOrVectors = DatasetOrVectors.fromDataset(ds);
+
+            final CuVSMatrix dataset;
+            if (numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
+                if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
+                    // Direct access to mmapped file
+                    dataset = DatasetUtils.getInstance()
+                        .fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
+                } else {
+                    var builder = CuVSMatrix.hostBuilder(numVectors, fieldInfo.getVectorDimension(), dataType);
+                    // Read vector-by-vector
+                    if (dataType == CuVSMatrix.DataType.FLOAT) {
+                        float[] vector = new float[fieldInfo.getVectorDimension()];
+                        for (int i = 0; i < numVectors; ++i) {
+                            input.readFloats(vector, 0, fieldInfo.getVectorDimension());
+                            builder.addVector(vector);
+                        }
+                    } else {
+                        assert dataType == CuVSMatrix.DataType.BYTE;
+                        byte[] vector = new byte[fieldInfo.getVectorDimension()];
+                        for (int i = 0; i < numVectors; ++i) {
+                            input.readBytes(vector, 0, fieldInfo.getVectorDimension());
+                            builder.addVector(vector);
+                        }
+                    }
+                    dataset = builder.build();
+                }
             } else {
-                // assert numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD : "numVectors: " + numVectors;
                 // we don't really need real value for vectors here,
                 // we just build a mock graph where every node is connected to every other node
-                float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
-                datasetOrVectors = DatasetOrVectors.fromArray(vectors);
+                dataset = null;
+            }
+            try {
+                writeField(fieldInfo, dataset, numVectors);
+            } finally {
+                if (dataset != null) {
+                    dataset.close();
+                }
             }
-            writeFieldInternal(fieldInfo, datasetOrVectors);
         } finally {
             org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
         }
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/reflect/VectorsFormatReflectionUtils.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/reflect/VectorsFormatReflectionUtils.java
new file mode 100644
index 0000000000000..7456a24b340c0
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/reflect/VectorsFormatReflectionUtils.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.reflect;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
+import org.apache.lucene.store.IndexOutput;
+import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
+
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
+
+/**
+ * Reads the private {@link IndexOutput} fields of the flat vectors writers through {@link VarHandle}s.
+ * Class initialization throws an {@link AssertionError} if the private lookups or fields cannot be resolved.
+ */
+public class VectorsFormatReflectionUtils {
+
+    private static final VarHandle FLAT_VECTOR_DATA_HANDLE;
+    private static final VarHandle QUANTIZED_VECTOR_DATA_HANDLE;
+    private static final VarHandle DELEGATE_WRITER_HANDLE;
+
+    static final Class<?> L99_SQ_VW_CLS = Lucene99ScalarQuantizedVectorsWriter.class;
+    static final Class<?> L99_F_VW_CLS = Lucene99FlatVectorsWriter.class;
+    static final Class<?> ES814_SQ_VW_CLS = ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter.class;
+
+    static {
+        try {
+            var lookup = MethodHandles.privateLookupIn(L99_F_VW_CLS, MethodHandles.lookup());
+            FLAT_VECTOR_DATA_HANDLE = lookup.findVarHandle(L99_F_VW_CLS, "vectorData", IndexOutput.class);
+
+            lookup = MethodHandles.privateLookupIn(L99_SQ_VW_CLS, MethodHandles.lookup());
+            QUANTIZED_VECTOR_DATA_HANDLE = lookup.findVarHandle(L99_SQ_VW_CLS, "quantizedVectorData", IndexOutput.class);
+
+            lookup = MethodHandles.privateLookupIn(ES814_SQ_VW_CLS, MethodHandles.lookup());
+            DELEGATE_WRITER_HANDLE = lookup.findVarHandle(ES814_SQ_VW_CLS, "delegate", L99_SQ_VW_CLS);
+
+        } catch (IllegalAccessException e) {
+            throw new AssertionError("should not happen, check opens", e);
+        } catch (ReflectiveOperationException e) {
+            throw new AssertionError(e);
+        }
+    }
+
+    /**
+     * Returns the {@code quantizedVectorData} output of the {@link Lucene99ScalarQuantizedVectorsWriter}
+     * held in the {@code delegate} field of the given {@code ES814ScalarQuantizedVectorsWriter}.
+     */
+    public static IndexOutput getQuantizedVectorDataIndexOutput(FlatVectorsWriter flatVectorWriter) {
+        assert flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter;
+        var delegate = (Lucene99ScalarQuantizedVectorsWriter) DELEGATE_WRITER_HANDLE.get(flatVectorWriter);
+        return (IndexOutput) QUANTIZED_VECTOR_DATA_HANDLE.get(delegate);
+    }
+
+    /**
+     * Returns the {@code vectorData} output of the given {@link Lucene99FlatVectorsWriter}.
+     */
+    public static IndexOutput getVectorDataIndexOutput(FlatVectorsWriter flatVectorWriter) {
+        assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
+        return (IndexOutput) FLAT_VECTOR_DATA_HANDLE.get(flatVectorWriter);
+    }
+}