2828import org .apache .lucene .store .IOContext ;
2929import org .apache .lucene .store .IndexInput ;
3030import org .apache .lucene .store .IndexOutput ;
31+ import org .apache .lucene .store .RandomAccessInput ;
3132import org .apache .lucene .util .VectorUtil ;
3233import org .elasticsearch .core .IOUtils ;
3334import org .elasticsearch .core .SuppressForbidden ;
@@ -237,36 +238,60 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
237238 private void mergeOneFieldIVF (FieldInfo fieldInfo , MergeState mergeState ) throws IOException {
238239 final int numVectors ;
239240 String tempRawVectorsFileName = null ;
241+ String docsFileName = null ;
240242 boolean success = false ;
241243 // build a float vector values with random access. In order to do that we dump the vectors to
242- // a temporary file
243- // and write the docID follow by the vector
244- try (IndexOutput out = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "ivf_" , IOContext .DEFAULT )) {
245- tempRawVectorsFileName = out .getName ();
246- // TODO do this better, we shouldn't have to write to a temp file, we should be able to
247- // to just from the merged vector values, the tricky part is the random access.
248- numVectors = writeFloatVectorValues (fieldInfo , out , MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState ));
249- CodecUtil .writeFooter (out );
250- success = true ;
244+ // a temporary file and if the segment is not dense, the docs to another file/
245+ try (
246+ IndexOutput vectorsOut = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "ivfvec_" , IOContext .DEFAULT )
247+ ) {
248+ tempRawVectorsFileName = vectorsOut .getName ();
249+ FloatVectorValues mergedFloatVectorValues = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
250+ // if the segment is dense, we don't need to do anything with docIds.
251+ boolean dense = mergedFloatVectorValues .size () == mergeState .segmentInfo .maxDoc ();
252+ try (
253+ IndexOutput docsOut = dense
254+ ? null
255+ : mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "ivfdoc_" , IOContext .DEFAULT )
256+ ) {
257+ if (docsOut != null ) {
258+ docsFileName = docsOut .getName ();
259+ }
260+ // TODO do this better, we shouldn't have to write to a temp file, we should be able to
261+ // to just from the merged vector values, the tricky part is the random access.
262+ numVectors = writeFloatVectorValues (fieldInfo , docsOut , vectorsOut , mergedFloatVectorValues );
263+ CodecUtil .writeFooter (vectorsOut );
264+ if (docsOut != null ) {
265+ CodecUtil .writeFooter (docsOut );
266+ }
267+ success = true ;
268+ }
251269 } finally {
252- if (success == false && tempRawVectorsFileName != null ) {
253- org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
270+ if (success == false ) {
271+ if (tempRawVectorsFileName != null ) {
272+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
273+ }
274+ if (docsFileName != null ) {
275+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , docsFileName );
276+ }
254277 }
255278 }
256- try (IndexInput in = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT )) {
257- float [] calculatedGlobalCentroid = new float [fieldInfo .getVectorDimension ()];
258- final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , in , numVectors );
279+ try (
280+ IndexInput vectors = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT );
281+ IndexInput docs = docsFileName == null ? null : mergeState .segmentInfo .dir .openInput (docsFileName , IOContext .DEFAULT )
282+ ) {
283+ final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , docs , vectors , numVectors );
259284 success = false ;
260285 long centroidOffset ;
261286 long centroidLength ;
262287 String centroidTempName = null ;
263288 int numCentroids ;
264289 IndexOutput centroidTemp = null ;
265290 CentroidAssignments centroidAssignments ;
291+ float [] calculatedGlobalCentroid = new float [fieldInfo .getVectorDimension ()];
266292 try {
267293 centroidTemp = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "civf_" , IOContext .DEFAULT );
268294 centroidTempName = centroidTemp .getName ();
269-
270295 centroidAssignments = calculateAndWriteCentroids (
271296 fieldInfo ,
272297 floatVectorValues ,
@@ -318,28 +343,34 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
318343 writeMeta (fieldInfo , centroidOffset , centroidLength , offsets , calculatedGlobalCentroid );
319344 }
320345 } finally {
346+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , centroidTempName );
347+ }
348+ } finally {
349+ if (docsFileName != null ) {
321350 org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (
322351 mergeState .segmentInfo .dir ,
323352 tempRawVectorsFileName ,
324- centroidTempName
353+ docsFileName
325354 );
355+ } else {
356+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
326357 }
327- } finally {
328- org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
329358 }
330359 }
331360
332- private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput randomAccessInput , int numVectors ) {
361+ private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput docs , IndexInput vectors , int numVectors )
362+ throws IOException {
333363 if (numVectors == 0 ) {
334364 return FloatVectorValues .fromFloats (List .of (), fieldInfo .getVectorDimension ());
335365 }
336- final long length = (long ) Float .BYTES * fieldInfo .getVectorDimension () + Integer . BYTES ;
366+ final long vectorLength = (long ) Float .BYTES * fieldInfo .getVectorDimension ();
337367 final float [] vector = new float [fieldInfo .getVectorDimension ()];
368+ final RandomAccessInput randomDocs = docs == null ? null : docs .randomAccessSlice (0 , docs .length ());
338369 return new FloatVectorValues () {
339370 @ Override
340371 public float [] vectorValue (int ord ) throws IOException {
341- randomAccessInput .seek (ord * length + Integer . BYTES );
342- randomAccessInput .readFloats (vector , 0 , vector .length );
372+ vectors .seek (ord * vectorLength );
373+ vectors .readFloats (vector , 0 , vector .length );
343374 return vector ;
344375 }
345376
@@ -360,27 +391,34 @@ public int size() {
360391
361392 @ Override
362393 public int ordToDoc (int ord ) {
394+ if (randomDocs == null ) {
395+ return ord ;
396+ }
363397 try {
364- randomAccessInput .seek (ord * length );
365- return randomAccessInput .readInt ();
398+ return randomDocs .readInt ((long ) ord * Integer .BYTES );
366399 } catch (IOException e ) {
367400 throw new UncheckedIOException (e );
368401 }
369402 }
370403 };
371404 }
372405
373- private static int writeFloatVectorValues (FieldInfo fieldInfo , IndexOutput out , FloatVectorValues floatVectorValues )
374- throws IOException {
406+ private static int writeFloatVectorValues (
407+ FieldInfo fieldInfo ,
408+ IndexOutput docsOut ,
409+ IndexOutput vectorsOut ,
410+ FloatVectorValues floatVectorValues
411+ ) throws IOException {
375412 int numVectors = 0 ;
376413 final ByteBuffer buffer = ByteBuffer .allocate (fieldInfo .getVectorDimension () * Float .BYTES ).order (ByteOrder .LITTLE_ENDIAN );
377414 final KnnVectorValues .DocIndexIterator iterator = floatVectorValues .iterator ();
378415 for (int docV = iterator .nextDoc (); docV != NO_MORE_DOCS ; docV = iterator .nextDoc ()) {
379416 numVectors ++;
380- float [] vector = floatVectorValues .vectorValue (iterator .index ());
381- out .writeInt (iterator .docID ());
382- buffer .asFloatBuffer ().put (vector );
383- out .writeBytes (buffer .array (), buffer .array ().length );
417+ buffer .asFloatBuffer ().put (floatVectorValues .vectorValue (iterator .index ()));
418+ vectorsOut .writeBytes (buffer .array (), buffer .array ().length );
419+ if (docsOut != null ) {
420+ docsOut .writeInt (iterator .docID ());
421+ }
384422 }
385423 return numVectors ;
386424 }
0 commit comments