 import org.apache.lucene.index.Sorter;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.elasticsearch.common.lucene.store.IndexOutputOutputStream;
 import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 
@@ -130,6 +132,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
         }
     }
 
+    @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
     private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction, float[][] vectors) throws Throwable {
         // TODO: should we Lucene HNSW index write here
         if (vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
@@ -139,6 +142,7 @@ private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction,
             return;
         }
 
+        int dimension = vectors[0].length;
         CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
             case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
             case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
@@ -159,16 +163,95 @@ private void buildAndwriteGPUIndex(VectorSimilarityFunction similarityFunction,
         }
 
         // TODO: do serialization through MemorySegment instead of a temp file
-        // serialize index for CPU consumption
+        // serialize index for CPU consumption to hnswlib format
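+        // The Cagra (GPU) graph is first serialized to a temp file in hnswlib format and then
+        // converted into the Lucene HNSW format in a second pass below.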
         startTime = System.nanoTime();
-        var gpuIndexOutputStream = new IndexOutputOutputStream(gpuIdx);
+        IndexOutput tempCagraHNSW = null;
+        boolean success = false;
         try {
-            index.serialize(gpuIndexOutputStream);
+            tempCagraHNSW = segmentWriteState.directory.createTempOutput(gpuIdx.getName(), "cagra_hnsw_temp", segmentWriteState.context);
+            var tempCagraHNSWOutputStream = new IndexOutputOutputStream(tempCagraHNSW);
+            index.serializeToHNSW(tempCagraHNSWOutputStream);
+            success = true;
             if (logger.isDebugEnabled()) {
-                logger.debug("Carga index serialized in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
+                logger.debug("Cagra index serialized to hnswlib format in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
             }
         } finally {
             index.destroyIndex();
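+            // on success, close the temp output normally so it can be reopened for reading below;
+            // on failure, close it quietly and delete the partial temp file so it does not leak into the directory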
+            if (success) {
+                IOUtils.close(tempCagraHNSW);
+            } else {
+                IOUtils.closeWhileHandlingException(tempCagraHNSW);
+                if (tempCagraHNSW != null) {
+                    org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempCagraHNSW.getName());
+                }
+            }
+        }
+
+        // convert hnswlib format to Lucene HNSW format
+        startTime = System.nanoTime();
+        success = false;
+        IndexInput tempCagraHNSWInput = null;
+        try {
+            tempCagraHNSWInput = segmentWriteState.directory.openInput(tempCagraHNSW.getName(), segmentWriteState.context);
+            // read the metadata from the hnswlib format;
+            // some of these fields are not used in the Lucene HNSW format
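+            // hnswlib header fields, in the order they are read below: offsetLevel0, maxElementCount, currElementCount,
+            // sizeDataPerElement, labelOffset, dataOffset, maxLevel, entryPointNode, maxM, maxM0, M, mult, efConstruction;
+            // only maxElementCount, labelOffset, dataOffset, maxLevel and maxM0 are needed for the conversion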
+            tempCagraHNSWInput.readLong(); // offsetLevel0
+            long maxElementCount = tempCagraHNSWInput.readLong();
+            tempCagraHNSWInput.readLong(); // currElementCount
+            long sizeDataPerElement = tempCagraHNSWInput.readLong();
+            long labelOffset = tempCagraHNSWInput.readLong();
+            long dataOffset = tempCagraHNSWInput.readLong();
+            int maxLevel = tempCagraHNSWInput.readInt();
+            tempCagraHNSWInput.readInt(); // entryPointNode
+            tempCagraHNSWInput.readLong(); // maxM
+            long maxM0 = tempCagraHNSWInput.readLong(); // number of graph connections
+            tempCagraHNSWInput.readLong(); // M
+            tempCagraHNSWInput.readLong(); // mult
+            tempCagraHNSWInput.readLong(); // efConstruction
+
+            assert (maxLevel == 1) : "Cagra index is flat, maxLevel must be: 1, got: " + maxLevel;
+            int maxGraphDegree = (int) maxM0;
+            int[] connections = new int[maxGraphDegree];
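+            // labelOffset - dataOffset is the byte length of one stored vector,
+            // so dividing by Float.BYTES recovers the vector dimension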
+            int dimensionCalculated = (int) ((labelOffset - dataOffset) / Float.BYTES);
+            assert (dimension == dimensionCalculated)
+                : "Cagra index vector dimension must be: " + dimension + ", got: " + dimensionCalculated;
+
+            // read graph from the cagra_hnswlib index and write it to the Lucene HNSW format
+            gpuIdx.writeInt((int) maxElementCount);
+            gpuIdx.writeInt((int) maxM0);
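+            // each hnswlib node record is: [linkCount:int][neighbours:int x linkCount][vector:float x dimension][label:long];
+            // only the adjacency list is copied, the vector data and label are skipped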
+            for (int i = 0; i < maxElementCount; i++) {
+                // read from the cagra_hnswlib index
+                int graphDegree = tempCagraHNSWInput.readInt();
+                assert (graphDegree == maxGraphDegree)
+                    : "In Cagra graph all nodes must have the same number of connections: " + maxGraphDegree + ", got: " + graphDegree;
+                for (int j = 0; j < graphDegree; j++) {
+                    connections[j] = tempCagraHNSWInput.readInt();
+                }
+                // Skip over the vector data
+                tempCagraHNSWInput.seek(tempCagraHNSWInput.getFilePointer() + dimension * Float.BYTES);
+                // Skip over the label/id
+                tempCagraHNSWInput.seek(tempCagraHNSWInput.getFilePointer() + Long.BYTES);
+
+                // write graph
+                gpuIdx.writeVInt(graphDegree);
+                for (int neighbor : connections) {
+                    gpuIdx.writeVInt(neighbor);
+                }
+            }
+
+            success = true;
+            if (logger.isDebugEnabled()) {
+                logger.debug("cagra_hnsw index serialized to Lucene HNSW in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
+            }
+        } finally {
+            if (success) {
+                IOUtils.close(tempCagraHNSWInput);
+            } else {
+                IOUtils.closeWhileHandlingException(tempCagraHNSWInput);
+            }
+            if (tempCagraHNSW != null) {
+                org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempCagraHNSW.getName());
+            }
         }
     }
 