4242
4343import java .io .IOException ;
4444import java .io .UncheckedIOException ;
45+ import java .lang .foreign .Arena ;
46+ import java .lang .foreign .MemorySegment ;
4547import java .nio .ByteBuffer ;
4648import java .nio .ByteOrder ;
49+ import java .nio .channels .FileChannel ;
50+ import java .nio .file .Path ;
51+ import java .nio .file .StandardOpenOption ;
4752import java .util .ArrayList ;
4853import java .util .Arrays ;
4954import java .util .List ;
@@ -460,36 +465,56 @@ public NodesIterator getNodesOnLevel(int level) {
460465 };
461466 }
462467
463- // TODO check with deleted documents
464468 @ Override
465469 @ SuppressForbidden (reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)" )
466470 public void mergeOneField (FieldInfo fieldInfo , MergeState mergeState ) throws IOException {
471+ int dims = fieldInfo .getVectorDimension ();
467472 flatVectorWriter .mergeOneField (fieldInfo , mergeState );
468- FloatVectorValues vectorValues = KnnVectorsWriter .MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
469- // save merged vector values to a temp file
473+ FloatVectorValues mergeFloatVectorValues = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
474+
475+ if (mergeFloatVectorValues .size () < MIN_NUM_VECTORS_FOR_GPU_BUILD ) {
476+ // TODO: check how deleted documents affect size value
477+ KnnVectorValues .DocIndexIterator iter = mergeFloatVectorValues .iterator ();
478+ float [] vector = new float [dims ];
479+ List <float []> vectorsList = new ArrayList <>();
480+ for (int docV = iter .nextDoc (); docV != NO_MORE_DOCS ; docV = iter .nextDoc ()) {
481+ System .arraycopy (mergeFloatVectorValues .vectorValue (iter .index ()), 0 , vector , 0 , dims );
482+ vectorsList .add (vector );
483+ }
484+ float [][] vectors = vectorsList .toArray (new float [0 ][]);
485+ DatasetOrVectors datasetOrVectors = new DatasetOrVectors (vectors );
486+ writeFieldInternal (fieldInfo , datasetOrVectors );
487+ return ;
488+ }
489+
490+
470491 final int numVectors ;
471492 String tempRawVectorsFileName = null ;
472493 boolean success = false ;
494+ // save merged vectors to a temporary file
473495 try (IndexOutput out = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "vec_" , IOContext .DEFAULT )) {
474496 tempRawVectorsFileName = out .getName ();
475- numVectors = writeFloatVectorValues (fieldInfo , out , MergedVectorValues . mergeFloatVectorValues ( fieldInfo , mergeState ) );
497+ numVectors = writeFloatVectorValues (fieldInfo , out , mergeFloatVectorValues );
476498 CodecUtil .writeFooter (out );
477499 success = true ;
478500 } finally {
479501 if (success == false && tempRawVectorsFileName != null ) {
480502 org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
481503 }
482504 }
483- try (IndexInput in = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT )) {
484- // TODO: Improve this (not acceptable): pass tempRawVectorsFileName for the gpuIndex construction through MemorySegment
485- final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , in , numVectors );
486- float [][] vectors = new float [numVectors ][fieldInfo .getVectorDimension ()];
487- float [] vector ;
488- for (int i = 0 ; i < numVectors ; i ++) {
489- vector = floatVectorValues .vectorValue (i );
490- System .arraycopy (vector , 0 , vectors [i ], 0 , vector .length );
491- }
492- DatasetOrVectors datasetOrVectors = new DatasetOrVectors (vectors );
505+ // Use MemorySegment to map the temp file and pass it as a dataset for building the GPU index
506+ try {
507+ final Path path = ((org .apache .lucene .store .FSDirectory ) mergeState .segmentInfo .dir ).getDirectory ().resolve (tempRawVectorsFileName );
508+ Arena arena = Arena .ofShared ();
509+ FileChannel fileChannel = FileChannel .open (path , StandardOpenOption .READ );
510+ final MemorySegment memorySegment = fileChannel .map (
511+ FileChannel .MapMode .READ_ONLY ,
512+ 0 ,
513+ fileChannel .size () - CodecUtil .footerLength (),
514+ arena
515+ );
516+ Dataset dataset = new DatasetImpl (arena , memorySegment , numVectors , fieldInfo .getVectorDimension ());
517+ DatasetOrVectors datasetOrVectors = new DatasetOrVectors (dataset , null );
493518 writeFieldInternal (fieldInfo , datasetOrVectors );
494519 } finally {
495520 org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
@@ -511,47 +536,6 @@ private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out,
511536 return numVectors ;
512537 }
513538
/**
 * Builds a {@link FloatVectorValues} view that reads vectors on demand from {@code randomAccessInput}.
 *
 * The offsets used below imply fixed-size records: one {@code int} followed by
 * {@code fieldInfo.getVectorDimension()} floats. The int is what {@code ordToDoc} returns.
 * NOTE(review): presumably this is the layout produced by writeFloatVectorValues - confirm.
 *
 * @param fieldInfo         supplies the vector dimension
 * @param randomAccessInput input that is seeked and read for every lookup
 * @param numVectors        number of records; reported by {@code size()}
 * @return an empty instance from {@code FloatVectorValues.fromFloats} when {@code numVectors == 0},
 *         otherwise an anonymous instance backed by {@code randomAccessInput}
 */
514- private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput randomAccessInput , int numVectors ) {
// Nothing to read: return an empty instance without touching the input.
515- if (numVectors == 0 ) {
516- return FloatVectorValues .fromFloats (List .of (), fieldInfo .getVectorDimension ());
517- }
// Byte size of one record (float payload plus one int); the cast makes the arithmetic use long.
518- final long length = (long ) Float .BYTES * fieldInfo .getVectorDimension () + Integer .BYTES ;
// Single scratch buffer shared by every vectorValue() call on the returned instance.
519- final float [] vector = new float [fieldInfo .getVectorDimension ()];
520- return new FloatVectorValues () {
// Seeks past the record's leading int and reads the floats into the shared buffer.
// The returned array is overwritten by the next call, so callers must copy it to keep the values.
521- @ Override
522- public float [] vectorValue (int ord ) throws IOException {
523- randomAccessInput .seek (ord * length + Integer .BYTES );
524- randomAccessInput .readFloats (vector , 0 , vector .length );
525- return vector ;
526- }
527-
// Returns the same instance, so "copies" share both the input and the scratch buffer.
528- @ Override
529- public FloatVectorValues copy () {
530- return this ;
531- }
532-
533- @ Override
534- public int dimension () {
535- return fieldInfo .getVectorDimension ();
536- }
537-
538- @ Override
539- public int size () {
540- return numVectors ;
541- }
542-
// Reads the int stored at the start of the record for this ordinal.
// The signature declares no checked exception, so IOException is rethrown as UncheckedIOException.
543- @ Override
544- public int ordToDoc (int ord ) {
545- try {
546- randomAccessInput .seek (ord * length );
547- return randomAccessInput .readInt ();
548- } catch (IOException e ) {
549- throw new UncheckedIOException (e );
550- }
551- }
552- };
553- }
554-
555539 private void writeMeta (
556540 FieldInfo field ,
557541 long vectorIndexOffset ,
0 commit comments