|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.gpu.codec; |
9 | 9 |
|
| 10 | +import com.nvidia.cuvs.CagraIndex; |
| 11 | +import com.nvidia.cuvs.CagraIndexParams; |
| 12 | +import com.nvidia.cuvs.CuVSResources; |
| 13 | + |
10 | 14 | import org.apache.lucene.codecs.CodecUtil; |
11 | 15 | import org.apache.lucene.codecs.KnnFieldVectorsWriter; |
12 | 16 | import org.apache.lucene.codecs.KnnVectorsWriter; |
|
15 | 19 | import org.apache.lucene.index.FieldInfo; |
16 | 20 | import org.apache.lucene.index.FloatVectorValues; |
17 | 21 | import org.apache.lucene.index.IndexFileNames; |
| 22 | +import org.apache.lucene.index.KnnVectorValues; |
18 | 23 | import org.apache.lucene.index.MergeState; |
19 | 24 | import org.apache.lucene.index.SegmentWriteState; |
20 | 25 | import org.apache.lucene.index.Sorter; |
21 | 26 | import org.apache.lucene.index.VectorEncoding; |
22 | 27 | import org.apache.lucene.index.VectorSimilarityFunction; |
23 | 28 | import org.apache.lucene.store.IndexOutput; |
| 29 | +import org.elasticsearch.common.lucene.store.IndexOutputOutputStream; |
24 | 30 | import org.elasticsearch.core.IOUtils; |
| 31 | +import org.elasticsearch.logging.LogManager; |
| 32 | +import org.elasticsearch.logging.Logger; |
25 | 33 |
|
26 | 34 | import java.io.IOException; |
27 | 35 | import java.util.ArrayList; |
28 | 36 | import java.util.List; |
29 | 37 |
|
30 | 38 | import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; |
| 39 | +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; |
31 | 40 |
|
32 | 41 | /** |
33 | 42 | * Writer for GPU-accelerated vectors. |
34 | 43 | */ |
35 | 44 | public class GPUVectorsWriter extends KnnVectorsWriter { |
| 45 | + private static final Logger logger = LogManager.getLogger(GPUVectorsWriter.class); |
| 46 | + // 2 for now based on https://github.com/rapidsai/cuvs/issues/666, but can be increased later |
| 47 | + private static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2; |
36 | 48 |
|
37 | 49 | private final List<FieldWriter> fieldWriters = new ArrayList<>(); |
38 | 50 | private final IndexOutput gpuIdx; |
39 | 51 | private final IndexOutput gpuMeta; |
40 | 52 | private final FlatVectorsWriter rawVectorDelegate; |
41 | 53 | private final SegmentWriteState segmentWriteState; |
| 54 | + private final CuVSResources cuVSResources; |
42 | 55 |
|
43 | 56 | @SuppressWarnings("this-escape") |
44 | 57 | public GPUVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException { |
| 58 | + this.cuVSResources = GPUVectorsFormat.cuVSResourcesOrNull(); |
| 59 | + if (cuVSResources == null) { |
| 60 | + throw new IllegalArgumentException("GPU based vector search is not supported on this platform or java version"); |
| 61 | + } |
45 | 62 | this.segmentWriteState = state; |
46 | 63 | this.rawVectorDelegate = rawVectorDelegate; |
47 | 64 | final String metaFileName = IndexFileNames.segmentFileName( |
@@ -95,30 +112,89 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc |
95 | 112 | @Override |
96 | 113 | public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { |
97 | 114 | rawVectorDelegate.flush(maxDoc, sortMap); |
| 115 | + // TODO: implement the case when sortMap != null |
| 116 | + |
98 | 117 | for (FieldWriter fieldWriter : fieldWriters) { |
99 | | - // TODO: Implement GPU-specific vector merging instead of bogus implementation |
| 118 | + // TODO: can we use MemorySegment instead of passing array of vectors |
| 119 | + float[][] vectors = fieldWriter.delegate.getVectors().toArray(float[][]::new); |
100 | 120 | long dataOffset = gpuIdx.alignFilePointer(Float.BYTES); |
101 | | - var vectors = fieldWriter.delegate.getVectors(); |
102 | | - for (int i = 0; i < vectors.size(); i++) { |
103 | | - gpuIdx.writeVInt(0); |
| 121 | + try { |
| 122 | + buildAndwriteGPUIndex(fieldWriter.fieldInfo.getVectorSimilarityFunction(), vectors); |
| 123 | + long dataLength = gpuIdx.getFilePointer() - dataOffset; |
| 124 | + writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength); |
| 125 | + } catch (IOException e) { |
| 126 | + throw e; |
| 127 | + } catch (Throwable t) { |
| 128 | + throw new IOException("Failed to write GPU index: ", t); |
| 129 | + } |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable { |
| 134 | + // TODO: should we Lucene HNSW index write here |
| 135 | + if (vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD) { |
| 136 | + if (logger.isDebugEnabled()) { |
| 137 | + logger.debug("Skip building carga index; vectors length {} < {}", vectors.length, MIN_NUM_VECTORS_FOR_GPU_BUILD); |
| 138 | + } |
| 139 | + return; |
| 140 | + } |
| 141 | + |
| 142 | + CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) { |
| 143 | + case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded; |
| 144 | + case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct; |
| 145 | + case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded; |
| 146 | + }; |
| 147 | + |
| 148 | + // TODO: expose cagra index params of intermediate graph degree, graph degre, algorithm, NNDescentNumIterations |
| 149 | + CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use? |
| 150 | + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) |
| 151 | + .withMetric(distanceType) |
| 152 | + .build(); |
| 153 | + |
| 154 | + // build index on GPU |
| 155 | + long startTime = System.nanoTime(); |
| 156 | + var index = CagraIndex.newBuilder(cuVSResources).withDataset(vectors).withIndexParams(params).build(); |
| 157 | + if (logger.isDebugEnabled()) { |
| 158 | + logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, vectors.length); |
| 159 | + } |
| 160 | + |
| 161 | + // TODO: do serialization through MemorySegment instead of a temp file |
| 162 | + // serialize index for CPU consumption |
| 163 | + startTime = System.nanoTime(); |
| 164 | + var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx); |
| 165 | + try { |
| 166 | + index.serialize(gpuIndexOutputStream); |
| 167 | + if (logger.isDebugEnabled()) { |
| 168 | + logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0); |
104 | 169 | } |
105 | | - long dataLength = gpuIdx.getFilePointer() - dataOffset; |
106 | | - writeMeta(fieldWriter.fieldInfo, dataOffset, dataLength); |
| 170 | + } finally { |
| 171 | + index.destroyIndex(); |
107 | 172 | } |
108 | 173 | } |
109 | 174 |
|
110 | 175 | @Override |
111 | 176 | public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { |
112 | 177 | if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { |
113 | 178 | rawVectorDelegate.mergeOneField(fieldInfo, mergeState); |
114 | | - // TODO: Implement GPU-specific vector merging instead of bogus implementation |
115 | | - FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); |
| 179 | + FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); |
| 180 | + // TODO: more efficient way to pass merged vector values to gpuIndex construction |
| 181 | + KnnVectorValues.DocIndexIterator iter = vectorValues.iterator(); |
| 182 | + List<float[]> vectorList = new ArrayList<>(); |
| 183 | + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { |
| 184 | + vectorList.add(vectorValues.vectorValue(iter.index())); |
| 185 | + } |
| 186 | + float[][] vectors = vectorList.toArray(new float[0][]); |
| 187 | + |
116 | 188 | long dataOffset = gpuIdx.alignFilePointer(Float.BYTES); |
117 | | - for (int i = 0; i < floatVectorValues.size(); i++) { |
118 | | - gpuIdx.writeVInt(0); |
| 189 | + try { |
| 190 | + buildAndwriteGPUIndex(fieldInfo.getVectorSimilarityFunction(), vectors); |
| 191 | + long dataLength = gpuIdx.getFilePointer() - dataOffset; |
| 192 | + writeMeta(fieldInfo, dataOffset, dataLength); |
| 193 | + } catch (IOException e) { |
| 194 | + throw e; |
| 195 | + } catch (Throwable t) { |
| 196 | + throw new IOException("Failed to write GPU index: ", t); |
119 | 197 | } |
120 | | - long dataLength = gpuIdx.getFilePointer() - dataOffset; |
121 | | - writeMeta(fieldInfo, dataOffset, dataLength); |
122 | 198 | } else { |
123 | 199 | rawVectorDelegate.mergeOneField(fieldInfo, mergeState); |
124 | 200 | } |
@@ -157,6 +233,7 @@ public final void finish() throws IOException { |
157 | 233 | @Override |
158 | 234 | public final void close() throws IOException { |
159 | 235 | IOUtils.close(rawVectorDelegate, gpuMeta, gpuIdx); |
| 236 | + cuVSResources.close(); |
160 | 237 | } |
161 | 238 |
|
162 | 239 | @Override |
|
0 commit comments