1616import org .apache .lucene .codecs .KnnVectorsWriter ;
1717import org .apache .lucene .codecs .hnsw .FlatFieldVectorsWriter ;
1818import org .apache .lucene .codecs .hnsw .FlatVectorsWriter ;
19+ import org .apache .lucene .index .ByteVectorValues ;
1920import org .apache .lucene .index .DocsWithFieldSet ;
2021import org .apache .lucene .index .FieldInfo ;
2122import org .apache .lucene .index .FloatVectorValues ;
3536import org .apache .lucene .util .hnsw .HnswGraph ;
3637import org .apache .lucene .util .hnsw .HnswGraph .NodesIterator ;
3738import org .apache .lucene .util .packed .DirectMonotonicWriter ;
39+ import org .apache .lucene .util .quantization .ScalarQuantizer ;
3840import org .elasticsearch .core .IOUtils ;
3941import org .elasticsearch .core .SuppressForbidden ;
42+ import org .elasticsearch .index .codec .vectors .ES814ScalarQuantizedVectorsFormat ;
4043import org .elasticsearch .logging .LogManager ;
4144import org .elasticsearch .logging .Logger ;
4245
4952import java .util .Objects ;
5053
5154import static org .apache .lucene .codecs .lucene99 .Lucene99HnswVectorsReader .SIMILARITY_FUNCTIONS ;
55+ import static org .apache .lucene .codecs .lucene99 .Lucene99ScalarQuantizedVectorsWriter .mergeAndRecalculateQuantiles ;
5256import static org .apache .lucene .search .DocIdSetIterator .NO_MORE_DOCS ;
5357import static org .elasticsearch .xpack .gpu .codec .ESGpuHnswVectorsFormat .LUCENE99_HNSW_META_CODEC_NAME ;
5458import static org .elasticsearch .xpack .gpu .codec .ESGpuHnswVectorsFormat .LUCENE99_HNSW_META_EXTENSION ;
@@ -75,6 +79,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
7579
7680 private final List <FieldWriter > fields = new ArrayList <>();
7781 private boolean finished ;
82+ private final CuVSMatrix .DataType dataType ;
7883
7984 ESGpuHnswVectorsWriter (
8085 CuVSResourceManager cuVSResourceManager ,
@@ -88,6 +93,11 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
8893 this .M = M ;
8994 this .beamWidth = beamWidth ;
9095 this .flatVectorWriter = flatVectorWriter ;
96+ if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat .ES814ScalarQuantizedVectorsWriter ) {
97+ dataType = CuVSMatrix .DataType .BYTE ;
98+ } else {
99+ dataType = CuVSMatrix .DataType .FLOAT ;
100+ }
91101 this .segmentWriteState = state ;
92102 String metaFileName = IndexFileNames .segmentFileName (state .segmentInfo .name , state .segmentSuffix , LUCENE99_HNSW_META_EXTENSION );
93103 String indexDataFileName = IndexFileNames .segmentFileName (
@@ -411,13 +421,17 @@ public NodesIterator getNodesOnLevel(int level) {
411421 @ SuppressForbidden (reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)" )
412422 public void mergeOneField (FieldInfo fieldInfo , MergeState mergeState ) throws IOException {
413423 flatVectorWriter .mergeOneField (fieldInfo , mergeState );
414- // save merged vector values to a temp file
415424 final int numVectors ;
416425 String tempRawVectorsFileName = null ;
417426 boolean success = false ;
427+ // save merged vector values to a temp file
418428 try (IndexOutput out = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "vec_" , IOContext .DEFAULT )) {
419429 tempRawVectorsFileName = out .getName ();
420- numVectors = writeFloatVectorValues (fieldInfo , out , MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState ));
430+ if (dataType == CuVSMatrix .DataType .BYTE ) {
431+ numVectors = writeByteVectorValues (out , getMergedByteVectorValues (fieldInfo , mergeState ));
432+ } else {
433+ numVectors = writeFloatVectorValues (fieldInfo , out , MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState ));
434+ }
421435 CodecUtil .writeFooter (out );
422436 success = true ;
423437 } finally {
@@ -429,9 +443,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
429443 DatasetOrVectors datasetOrVectors ;
430444 var input = FilterIndexInput .unwrapOnlyTest (in );
431445 if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD ) {
432- var ds = DatasetUtils .getInstance ().fromInput (memorySegmentAccessInput , numVectors , fieldInfo .getVectorDimension ());
446+ var ds = DatasetUtils .getInstance ()
447+ .fromInput (memorySegmentAccessInput , numVectors , fieldInfo .getVectorDimension (), dataType );
433448 datasetOrVectors = DatasetOrVectors .fromDataset (ds );
434449 } else {
450+ // TODO fix for byte vectors
435451 var fa = copyVectorsIntoArray (in , fieldInfo , numVectors );
436452 datasetOrVectors = DatasetOrVectors .fromArray (fa );
437453 }
@@ -441,6 +457,31 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
441457 }
442458 }
443459
460+ private ByteVectorValues getMergedByteVectorValues (FieldInfo fieldInfo , MergeState mergeState ) throws IOException {
461+ // TODO: expose confidence interval from the format
462+ final byte bits = 7 ;
463+ final Float confidenceInterval = null ;
464+ ScalarQuantizer quantizer = mergeAndRecalculateQuantiles (mergeState , fieldInfo , confidenceInterval , bits );
465+ MergedQuantizedVectorValues byteVectorValues = MergedQuantizedVectorValues .mergeQuantizedByteVectorValues (
466+ fieldInfo ,
467+ mergeState ,
468+ quantizer
469+ );
470+ return byteVectorValues ;
471+ }
472+
473+ private static int writeByteVectorValues (IndexOutput out , ByteVectorValues vectorValues ) throws IOException {
474+ int numVectors = 0 ;
475+ byte [] vector ;
476+ final KnnVectorValues .DocIndexIterator iterator = vectorValues .iterator ();
477+ for (int docV = iterator .nextDoc (); docV != NO_MORE_DOCS ; docV = iterator .nextDoc ()) {
478+ numVectors ++;
479+ vector = vectorValues .vectorValue (iterator .index ());
480+ out .writeBytes (vector , vector .length );
481+ }
482+ return numVectors ;
483+ }
484+
444485 static float [][] copyVectorsIntoArray (IndexInput in , FieldInfo fieldInfo , int numVectors ) throws IOException {
445486 final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , in , numVectors );
446487 float [][] vectors = new float [numVectors ][fieldInfo .getVectorDimension ()];
0 commit comments