Skip to content

Commit 3f05ba0

Browse files
Build cagra index (iter2)
1 parent 8cee75e commit 3f05ba0

File tree

1 file changed

+87
-4
lines changed

1 file changed

+87
-4
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsWriter.java

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
import org.apache.lucene.index.Sorter;
2626
import org.apache.lucene.index.VectorEncoding;
2727
import org.apache.lucene.index.VectorSimilarityFunction;
28+
import org.apache.lucene.store.IndexInput;
2829
import org.apache.lucene.store.IndexOutput;
2930
import org.elasticsearch.common.lucene.store.IndexOutputOutputStream;
3031
import org.elasticsearch.core.IOUtils;
32+
import org.elasticsearch.core.SuppressForbidden;
3133
import org.elasticsearch.logging.LogManager;
3234
import org.elasticsearch.logging.Logger;
3335

@@ -130,6 +132,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
130132
}
131133
}
132134

135+
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
133136
private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
134137
// TODO: should we write a Lucene HNSW index here?
135138
if (vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
@@ -139,6 +142,7 @@ private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction,
139142
return;
140143
}
141144

145+
int dimension = vectors[0].length;
142146
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
143147
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
144148
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
@@ -159,16 +163,95 @@ private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction,
159163
}
160164

161165
// TODO: do serialization through MemorySegment instead of a temp file
162-
// serialize index for CPU consumption
166+
// serialize index for CPU consumption to hnswlib format
163167
startTime = System.nanoTime();
164-
var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx);
168+
IndexOutput tempCagraHNSW = null;
169+
boolean success = false;
165170
try {
166-
index.serialize(gpuIndexOutputStream);
171+
tempCagraHNSW = segmentWriteState.directory.createTempOutput(gpuIdx.getName(), "cagra_hnws_temp", segmentWriteState.context);
172+
var tempCagraHNSWOutputStream = new IndexOutputOutputStream(tempCagraHNSW);
173+
index.serializeToHNSW(tempCagraHNSWOutputStream);
174+
success = true;
167175
if (logger.isDebugEnabled()) {
168-
logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
176+
logger.debug("Carga index serialized to hnswlib format in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
169177
}
170178
} finally {
171179
index.destroyIndex();
180+
if (success) {
181+
IOUtils.close(tempCagraHNSW);
182+
} else {
183+
IOUtils.closeWhileHandlingException(tempCagraHNSW);
184+
if (tempCagraHNSW != null) {
185+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempCagraHNSW.getName());
186+
}
187+
}
188+
}
189+
190+
// convert hnswlib format to Lucene HNSW format
191+
startTime = System.nanoTime();
192+
success = false;
193+
IndexInput tempCagraHNSWInput = null;
194+
try {
195+
tempCagraHNSWInput = segmentWriteState.directory.openInput(tempCagraHNSW.getName(), segmentWriteState.context);
196+
// read the metadata from the hnswlib format
197+
// some of them are not used in Lucene HNSW format
198+
tempCagraHNSWInput.readLong(); // offSetLevel0
199+
long maxElementCount = tempCagraHNSWInput.readLong();
200+
tempCagraHNSWInput.readLong(); // currElementCount
201+
long sizeDataPerElement = tempCagraHNSWInput.readLong();
202+
long labelOffset = tempCagraHNSWInput.readLong();
203+
long dataOffset = tempCagraHNSWInput.readLong();
204+
int maxLevel = tempCagraHNSWInput.readInt();
205+
tempCagraHNSWInput.readInt(); // entryPointNode
206+
tempCagraHNSWInput.readLong(); // maxM
207+
long maxM0 = tempCagraHNSWInput.readLong(); // number of graph connections
208+
tempCagraHNSWInput.readLong(); // M
209+
tempCagraHNSWInput.readLong(); // mult
210+
tempCagraHNSWInput.readLong(); // efConstruction
211+
212+
assert (maxLevel == 1) : "Cagra index is flat, maxLevel must be: 1, got: " + maxLevel;
213+
int maxGraphDegree = (int) maxM0;
214+
int[] connections = new int[maxGraphDegree];
215+
int dimensionCalculated = (int) ((labelOffset - dataOffset) / Float.BYTES);
216+
assert (dimension == dimensionCalculated)
217+
: "Cagra index vector dimension must be: " + dimension + ", got: " + dimensionCalculated;
218+
219+
// read graph from the cagra_hnswlib index and write it to the Lucene HNSW format
220+
gpuIdx.writeInt((int) maxElementCount);
221+
gpuIdx.writeInt((int) maxM0);
222+
for (int i = 0; i < maxElementCount; i++) {
223+
// read from the cagra_hnswlib index
224+
int graphDegree = tempCagraHNSWInput.readInt();
225+
assert (graphDegree == maxGraphDegree)
226+
: "In Cagra graph all nodes must have the same number of connections : " + maxGraphDegree + ", got" + graphDegree;
227+
for (int j = 0; j < graphDegree; j++) {
228+
connections[j] = tempCagraHNSWInput.readInt();
229+
}
230+
// Skip over the vector data
231+
tempCagraHNSWInput.seek(tempCagraHNSWInput.getFilePointer() + dimension * Float.BYTES);
232+
// Skip over the label/id
233+
tempCagraHNSWInput.seek(tempCagraHNSWInput.getFilePointer() + Long.BYTES);
234+
235+
// write graph
236+
gpuIdx.writeVInt(graphDegree);
237+
for (int neighbor : connections) {
238+
gpuIdx.writeVInt(neighbor);
239+
}
240+
}
241+
242+
success = true;
243+
if (logger.isDebugEnabled()) {
244+
logger.debug("cagra_hnws index serialized to Lucene HNSW in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
245+
}
246+
} finally {
247+
if (success) {
248+
IOUtils.close(tempCagraHNSWInput);
249+
} else {
250+
IOUtils.closeWhileHandlingException(tempCagraHNSWInput);
251+
}
252+
if (tempCagraHNSW != null) {
253+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempCagraHNSW.getName());
254+
}
172255
}
173256
}
174257

0 commit comments

Comments
 (0)