1010import com .nvidia .cuvs .CagraIndex ;
1111import com .nvidia .cuvs .CagraIndexParams ;
1212import com .nvidia .cuvs .CuVSResources ;
13+ import com .nvidia .cuvs .Dataset ;
1314
1415import org .apache .lucene .codecs .CodecUtil ;
1516import org .apache .lucene .codecs .KnnFieldVectorsWriter ;
2627import org .apache .lucene .index .Sorter ;
2728import org .apache .lucene .index .VectorEncoding ;
2829import org .apache .lucene .index .VectorSimilarityFunction ;
30+ import org .apache .lucene .store .IOContext ;
2931import org .apache .lucene .store .IndexInput ;
3032import org .apache .lucene .store .IndexOutput ;
3133import org .apache .lucene .util .RamUsageEstimator ;
3941import org .elasticsearch .logging .Logger ;
4042
4143import java .io .IOException ;
44+ import java .io .UncheckedIOException ;
45+ import java .nio .ByteBuffer ;
46+ import java .nio .ByteOrder ;
4247import java .util .ArrayList ;
4348import java .util .Arrays ;
4449import java .util .List ;
@@ -166,9 +171,46 @@ public long ramBytesUsed() {
166171 return total ;
167172 }
168173
174+ private static final class DatasetOrVectors {
175+ private final Dataset dataset ;
176+ private final float [][] vectors ;
177+
178+ DatasetOrVectors (float [][] vectors ) {
179+ this (
180+ vectors .length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : Dataset .ofArray (vectors ),
181+ vectors .length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? vectors : null
182+ );
183+ validateState ();
184+ }
185+
186+ private DatasetOrVectors (Dataset dataset , float [][] vectors ) {
187+ this .dataset = dataset ;
188+ this .vectors = vectors ;
189+ validateState ();
190+ }
191+
192+ private void validateState () {
193+ if ((dataset == null && vectors == null ) || (dataset != null && vectors != null )) {
194+ throw new IllegalStateException ("Exactly one of dataset or vectors must be non-null" );
195+ }
196+ }
197+
198+ int size () {
199+ return dataset != null ? dataset .size () : vectors .length ;
200+ }
201+
202+ Dataset getDataset () {
203+ return dataset ;
204+ }
205+
206+ float [][] getVectors () {
207+ return vectors ;
208+ }
209+ }
210+
169211 private void writeField (FieldWriter fieldWriter ) throws IOException {
170212 float [][] vectors = fieldWriter .flatFieldVectorsWriter .getVectors ().toArray (float [][]::new );
171- writeFieldInternal (fieldWriter .fieldInfo , vectors );
213+ writeFieldInternal (fieldWriter .fieldInfo , new DatasetOrVectors ( vectors ) );
172214 }
173215
174216 private void writeSortingField (FieldWriter fieldData , Sorter .DocMap sortMap ) throws IOException {
@@ -177,12 +219,13 @@ private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) thr
177219 throw new UnsupportedOperationException ("Writing field with index sorted needs to be implemented." );
178220 }
179221
180- private void writeFieldInternal (FieldInfo fieldInfo , float [][] vectors ) throws IOException {
222+ private void writeFieldInternal (FieldInfo fieldInfo , DatasetOrVectors datasetOrVectors ) throws IOException {
181223 try {
182224 long vectorIndexOffset = vectorIndex .getFilePointer ();
183225 int [][] graphLevelNodeOffsets = new int [1 ][];
184226 HnswGraph mockGraph ;
185- if (vectors .length < MIN_NUM_VECTORS_FOR_GPU_BUILD ) {
227+ if (datasetOrVectors .vectors != null ) {
228+ float [][] vectors = datasetOrVectors .vectors ;
186229 if (logger .isDebugEnabled ()) {
187230 logger .debug (
188231 "Skip building carga index; vectors length {} < {} (min for GPU)" ,
@@ -192,12 +235,12 @@ private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws I
192235 }
193236 mockGraph = writeGraph (vectors , graphLevelNodeOffsets );
194237 } else {
195- String tempCagraHNSWFileName = buildGPUIndex (fieldInfo .getVectorSimilarityFunction (), vectors );
238+ String tempCagraHNSWFileName = buildGPUIndex (fieldInfo .getVectorSimilarityFunction (), datasetOrVectors . dataset );
196239 assert tempCagraHNSWFileName != null : "GPU index should be built for field: " + fieldInfo .name ;
197240 mockGraph = writeGraph (tempCagraHNSWFileName , graphLevelNodeOffsets );
198241 }
199242 long vectorIndexLength = vectorIndex .getFilePointer () - vectorIndexOffset ;
200- writeMeta (fieldInfo , vectorIndexOffset , vectorIndexLength , vectors . length , mockGraph , graphLevelNodeOffsets );
243+ writeMeta (fieldInfo , vectorIndexOffset , vectorIndexLength , datasetOrVectors . size () , mockGraph , graphLevelNodeOffsets );
201244 } catch (IOException e ) {
202245 throw e ;
203246 } catch (Throwable t ) {
@@ -206,7 +249,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws I
206249 }
207250
208251 @ SuppressForbidden (reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)" )
209- private String buildGPUIndex (VectorSimilarityFunction similarityFunction , float [][] vectors ) throws Throwable {
252+ private String buildGPUIndex (VectorSimilarityFunction similarityFunction , Dataset dataset ) throws Throwable {
210253 CagraIndexParams .CuvsDistanceType distanceType = switch (similarityFunction ) {
211254 case EUCLIDEAN -> CagraIndexParams .CuvsDistanceType .L2Expanded ;
212255 case DOT_PRODUCT , MAXIMUM_INNER_PRODUCT -> CagraIndexParams .CuvsDistanceType .InnerProduct ;
@@ -221,9 +264,9 @@ private String buildGPUIndex(VectorSimilarityFunction similarityFunction, float[
221264
222265 // build index on GPU
223266 long startTime = System .nanoTime ();
224- var index = CagraIndex .newBuilder (cuVSResources ).withDataset (vectors ).withIndexParams (params ).build ();
267+ var index = CagraIndex .newBuilder (cuVSResources ).withDataset (dataset ).withIndexParams (params ).build ();
225268 if (logger .isDebugEnabled ()) {
226- logger .debug ("Carga index created in: {} ms; #num vectors: {}" , (System .nanoTime () - startTime ) / 1_000_000.0 , vectors . length );
269+ logger .debug ("Carga index created in: {} ms; #num vectors: {}" , (System .nanoTime () - startTime ) / 1_000_000.0 , dataset . size () );
227270 }
228271
229272 // TODO: do serialization through MemorySegment instead of a temp file
@@ -419,18 +462,94 @@ public NodesIterator getNodesOnLevel(int level) {
419462
420463 // TODO check with deleted documents
421464 @ Override
465+ @ SuppressForbidden (reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)" )
422466 public void mergeOneField (FieldInfo fieldInfo , MergeState mergeState ) throws IOException {
423467 flatVectorWriter .mergeOneField (fieldInfo , mergeState );
424468 FloatVectorValues vectorValues = KnnVectorsWriter .MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
425- // TODO: more efficient way to pass merged vector values to gpuIndex construction
426- KnnVectorValues .DocIndexIterator iter = vectorValues .iterator ();
427- List <float []> vectorList = new ArrayList <>();
428- for (int docV = iter .nextDoc (); docV != NO_MORE_DOCS ; docV = iter .nextDoc ()) {
429- vectorList .add (vectorValues .vectorValue (iter .index ()));
469+ // save merged vector values to a temp file
470+ final int numVectors ;
471+ String tempRawVectorsFileName = null ;
472+ boolean success = false ;
473+ try (IndexOutput out = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "vec_" , IOContext .DEFAULT )) {
474+ tempRawVectorsFileName = out .getName ();
475+ numVectors = writeFloatVectorValues (fieldInfo , out , MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState ));
476+ CodecUtil .writeFooter (out );
477+ success = true ;
478+ } finally {
479+ if (success == false && tempRawVectorsFileName != null ) {
480+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
481+ }
482+ }
483+ try (IndexInput in = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT )) {
484+ // TODO: Improve this (not acceptable): pass tempRawVectorsFileName for the gpuIndex construction through MemorySegment
485+ final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , in , numVectors );
486+ float [][] vectors = new float [numVectors ][fieldInfo .getVectorDimension ()];
487+ float [] vector ;
488+ for (int i = 0 ; i < numVectors ; i ++) {
489+ vector = floatVectorValues .vectorValue (i );
490+ System .arraycopy (vector , 0 , vectors [i ], 0 , vector .length );
491+ }
492+ DatasetOrVectors datasetOrVectors = new DatasetOrVectors (vectors );
493+ writeFieldInternal (fieldInfo , datasetOrVectors );
494+ } finally {
495+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
430496 }
431- float [][] vectors = vectorList .toArray (new float [0 ][]);
497+ }
498+
499+ private static int writeFloatVectorValues (FieldInfo fieldInfo , IndexOutput out , FloatVectorValues floatVectorValues )
500+ throws IOException {
501+ int numVectors = 0 ;
502+ final ByteBuffer buffer = ByteBuffer .allocate (fieldInfo .getVectorDimension () * Float .BYTES ).order (ByteOrder .LITTLE_ENDIAN );
503+ final KnnVectorValues .DocIndexIterator iterator = floatVectorValues .iterator ();
504+ for (int docV = iterator .nextDoc (); docV != NO_MORE_DOCS ; docV = iterator .nextDoc ()) {
505+ numVectors ++;
506+ float [] vector = floatVectorValues .vectorValue (iterator .index ());
507+ out .writeInt (iterator .docID ());
508+ buffer .asFloatBuffer ().put (vector );
509+ out .writeBytes (buffer .array (), buffer .array ().length );
510+ }
511+ return numVectors ;
512+ }
513+
514+ private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput randomAccessInput , int numVectors ) {
515+ if (numVectors == 0 ) {
516+ return FloatVectorValues .fromFloats (List .of (), fieldInfo .getVectorDimension ());
517+ }
518+ final long length = (long ) Float .BYTES * fieldInfo .getVectorDimension () + Integer .BYTES ;
519+ final float [] vector = new float [fieldInfo .getVectorDimension ()];
520+ return new FloatVectorValues () {
521+ @ Override
522+ public float [] vectorValue (int ord ) throws IOException {
523+ randomAccessInput .seek (ord * length + Integer .BYTES );
524+ randomAccessInput .readFloats (vector , 0 , vector .length );
525+ return vector ;
526+ }
527+
528+ @ Override
529+ public FloatVectorValues copy () {
530+ return this ;
531+ }
432532
433- writeFieldInternal (fieldInfo , vectors );
533+ @ Override
534+ public int dimension () {
535+ return fieldInfo .getVectorDimension ();
536+ }
537+
538+ @ Override
539+ public int size () {
540+ return numVectors ;
541+ }
542+
543+ @ Override
544+ public int ordToDoc (int ord ) {
545+ try {
546+ randomAccessInput .seek (ord * length );
547+ return randomAccessInput .readInt ();
548+ } catch (IOException e ) {
549+ throw new UncheckedIOException (e );
550+ }
551+ }
552+ };
434553 }
435554
436555 private void writeMeta (
0 commit comments