1818import org .apache .lucene .store .IndexInput ;
1919import org .apache .lucene .store .IndexOutput ;
2020import org .apache .lucene .util .VectorUtil ;
21+ import org .apache .lucene .util .hnsw .IntToIntFunction ;
2122import org .elasticsearch .index .codec .vectors .cluster .HierarchicalKMeans ;
2223import org .elasticsearch .index .codec .vectors .cluster .KMeansResult ;
2324import org .elasticsearch .logging .LogManager ;
@@ -81,28 +82,27 @@ long[] buildAndWritePostingsLists(
8182 }
8283 // write the posting lists
8384 final long [] offsets = new long [centroidSupplier .size ()];
84- OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
8585 DocIdsWriter docIdsWriter = new DocIdsWriter ();
86- DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (
87- ES91OSQVectorsScorer .BULK_SIZE ,
88- quantizer ,
86+ DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (ES91OSQVectorsScorer .BULK_SIZE , postingsOutput );
87+ OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors (
8988 floatVectorValues ,
90- postingsOutput
89+ fieldInfo .getVectorDimension (),
90+ new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ())
9191 );
9292 for (int c = 0 ; c < centroidSupplier .size (); c ++) {
9393 float [] centroid = centroidSupplier .centroid (c );
94- // TODO: add back in sorting vectors by distance to centroid
9594 int [] cluster = assignmentsByCluster [c ];
9695 // TODO align???
9796 offsets [c ] = postingsOutput .getFilePointer ();
9897 int size = cluster .length ;
9998 postingsOutput .writeVInt (size );
10099 postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
100+ onHeapQuantizedVectors .reset (centroid , size , ord -> cluster [ord ]);
101101 // TODO we might want to consider putting the docIds in a separate file
102102 // to aid with only having to fetch vectors from slower storage when they are required
103103 // keeping them in the same file indicates we pull the entire file into cache
104104 docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster [j ]), size , postingsOutput );
105- bulkWriter .writeOrds ( j -> cluster [ j ], cluster . length , centroid );
105+ bulkWriter .writeVectors ( onHeapQuantizedVectors );
106106 }
107107
108108 if (logger .isDebugEnabled ()) {
@@ -161,6 +161,35 @@ long[] buildAndWritePostingsLists(
161161 mergeState .segmentInfo .dir .deleteFile (quantizedVectorsTemp .getName ());
162162 }
163163 }
164+ int [] centroidVectorCount = new int [centroidSupplier .size ()];
165+ for (int i = 0 ; i < assignments .length ; i ++) {
166+ centroidVectorCount [assignments [i ]]++;
167+ // if soar assignments are present, count them as well
168+ if (overspillAssignments .length > i && overspillAssignments [i ] != -1 ) {
169+ centroidVectorCount [overspillAssignments [i ]]++;
170+ }
171+ }
172+
173+ int [][] assignmentsByCluster = new int [centroidSupplier .size ()][];
174+ boolean [][] isOverspillByCluster = new boolean [centroidSupplier .size ()][];
175+ for (int c = 0 ; c < centroidSupplier .size (); c ++) {
176+ assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
177+ isOverspillByCluster [c ] = new boolean [centroidVectorCount [c ]];
178+ }
179+ Arrays .fill (centroidVectorCount , 0 );
180+
181+ for (int i = 0 ; i < assignments .length ; i ++) {
182+ int c = assignments [i ];
183+ assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
184+ // if soar assignments are present, add them to the cluster as well
185+ if (overspillAssignments .length > i ) {
186+ int s = overspillAssignments [i ];
187+ if (s != -1 ) {
188+ assignmentsByCluster [s ][centroidVectorCount [s ]] = i ;
189+ isOverspillByCluster [s ][centroidVectorCount [s ]++] = true ;
190+ }
191+ }
192+ }
164193 // now we can read the quantized vectors from the temporary file
165194 try (IndexInput quantizedVectorsInput = mergeState .segmentInfo .dir .openInput (quantizedVectorsTempName , IOContext .DEFAULT )) {
166195 final long [] offsets = new long [centroidSupplier .size ()];
@@ -169,26 +198,22 @@ long[] buildAndWritePostingsLists(
169198 fieldInfo .getVectorDimension ()
170199 );
171200 DocIdsWriter docIdsWriter = new DocIdsWriter ();
172- DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (
173- ES91OSQVectorsScorer .BULK_SIZE ,
174- quantizer ,
175- floatVectorValues ,
176- postingsOutput
177- );
201+ DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (ES91OSQVectorsScorer .BULK_SIZE , postingsOutput );
178202 for (int c = 0 ; c < centroidSupplier .size (); c ++) {
179203 float [] centroid = centroidSupplier .centroid (c );
180- // TODO: add back in sorting vectors by distance to centroid
181204 int [] cluster = assignmentsByCluster [c ];
205+ boolean [] isOverspill = isOverspillByCluster [c ];
182206 // TODO align???
183207 offsets [c ] = postingsOutput .getFilePointer ();
184208 int size = cluster .length ;
185209 postingsOutput .writeVInt (size );
186210 postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
211+ offHeapQuantizedVectors .reset (size , ord -> isOverspill [ord ], ord -> cluster [ord ]);
187212 // TODO we might want to consider putting the docIds in a separate file
188213 // to aid with only having to fetch vectors from slower storage when they are required
189214 // keeping them in the same file indicates we pull the entire file into cache
190215 docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster [j ]), size , postingsOutput );
191- bulkWriter .writeOrds ( j -> cluster [ j ], cluster . length , centroid );
216+ bulkWriter .writeVectors ( offHeapQuantizedVectors );
192217 }
193218
194219 if (logger .isDebugEnabled ()) {
@@ -370,47 +395,131 @@ public float[] centroid(int centroidOrdinal) throws IOException {
370395 }
371396 }
372397
373- static class OffHeapQuantizedVectors {
    /**
     * Iterator-style access to a sequence of bit-quantized vectors plus their per-vector
     * correction terms. Implementations in this file ({@code OnHeapQuantizedVectors},
     * {@code OffHeapQuantizedVectors}) are reset per cluster and consumed by
     * {@code DiskBBQBulkWriter.writeVectors}.
     */
    interface QuantizedVectorValues {
        /** Number of vectors in the current sequence. */
        int count();

        /**
         * Advances to the next vector and returns its packed binary-quantized bytes.
         * Note: both implementations return a reused scratch buffer, so the contents are
         * only valid until the following call.
         */
        byte[] next() throws IOException;

        /**
         * Correction terms for the vector most recently returned by {@link #next()}.
         * Both implementations throw {@link IllegalStateException} if called before the
         * first {@code next()}.
         */
        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
    }
405+
406+ interface IntToBooleanFunction {
407+ boolean apply (int ord );
408+ }
409+
410+ static class OnHeapQuantizedVectors implements QuantizedVectorValues {
411+ private final FloatVectorValues vectorValues ;
412+ private final OptimizedScalarQuantizer quantizer ;
413+ private final byte [] quantizedVector ;
414+ private final int [] quantizedVectorScratch ;
415+ private OptimizedScalarQuantizer .QuantizationResult corrections ;
416+ private float [] currentCentroid ;
417+ private IntToIntFunction ordTransformer = null ;
418+ private int currOrd = -1 ;
419+ private int count ;
420+
421+ OnHeapQuantizedVectors (FloatVectorValues vectorValues , int dimension , OptimizedScalarQuantizer quantizer ) {
422+ this .vectorValues = vectorValues ;
423+ this .quantizer = quantizer ;
424+ this .quantizedVector = new byte [BQVectorUtils .discretize (dimension , 64 ) / 8 ];
425+ this .quantizedVectorScratch = new int [dimension ];
426+ this .corrections = null ;
427+ }
428+
429+ private void reset (float [] centroid , int count , IntToIntFunction ordTransformer ) {
430+ this .currentCentroid = centroid ;
431+ this .ordTransformer = ordTransformer ;
432+ this .currOrd = -1 ;
433+ this .count = count ;
434+ }
435+
436+ @ Override
437+ public int count () {
438+ return count ;
439+ }
440+
441+ @ Override
442+ public byte [] next () throws IOException {
443+ if (currOrd >= count () - 1 ) {
444+ throw new IllegalStateException ("No more vectors to read, current ord: " + currOrd + ", count: " + count ());
445+ }
446+ currOrd ++;
447+ int ord = ordTransformer .apply (currOrd );
448+ float [] vector = vectorValues .vectorValue (ord );
449+ corrections = quantizer .scalarQuantize (vector , quantizedVectorScratch , (byte ) 1 , currentCentroid );
450+ BQVectorUtils .packAsBinary (quantizedVectorScratch , quantizedVector );
451+ return quantizedVector ;
452+ }
453+
454+ @ Override
455+ public OptimizedScalarQuantizer .QuantizationResult getCorrections () throws IOException {
456+ if (currOrd == -1 ) {
457+ throw new IllegalStateException ("No vector read yet, call next first" );
458+ }
459+ return corrections ;
460+ }
461+ }
462+
463+ static class OffHeapQuantizedVectors implements QuantizedVectorValues {
374464 private final IndexInput quantizedVectorsInput ;
375465 private final byte [] binaryScratch ;
376466 private final float [] corrections = new float [3 ];
377467
378468 private final int vectorByteSize ;
379469 private short bitSum ;
380470 private int currOrd = -1 ;
381- private boolean isOverspill = false ;
471+ private int count ;
472+ private IntToBooleanFunction isOverspill = null ;
473+ private IntToIntFunction ordTransformer = null ;
382474
383475 OffHeapQuantizedVectors (IndexInput quantizedVectorsInput , int dimension ) {
384476 this .quantizedVectorsInput = quantizedVectorsInput ;
385477 this .binaryScratch = new byte [BQVectorUtils .discretize (dimension , 64 ) / 8 ];
386478 this .vectorByteSize = (binaryScratch .length + 3 * Float .BYTES + Short .BYTES );
387479 }
388480
389- byte [] getVector (int ord , boolean isOverspill ) throws IOException {
390- readQuantizedVector (ord , isOverspill );
391- return binaryScratch ;
481+ private void reset (int count , IntToBooleanFunction isOverspill , IntToIntFunction ordTransformer ) {
482+ this .count = count ;
483+ this .isOverspill = isOverspill ;
484+ this .ordTransformer = ordTransformer ;
485+ this .currOrd = -1 ;
392486 }
393487
394- OptimizedScalarQuantizer .QuantizationResult getCorrections () throws IOException {
488+ @ Override
489+ public int count () {
490+ return count ;
491+ }
492+
493+ @ Override
494+ public byte [] next () throws IOException {
495+ if (currOrd >= count - 1 ) {
496+ throw new IllegalStateException ("No more vectors to read, current ord: " + currOrd + ", count: " + count );
497+ }
498+ currOrd ++;
499+ int ord = ordTransformer .apply (currOrd );
500+ boolean isOverspill = this .isOverspill .apply (currOrd );
501+ return getVector (ord , isOverspill );
502+ }
503+
504+ @ Override
505+ public OptimizedScalarQuantizer .QuantizationResult getCorrections () throws IOException {
395506 if (currOrd == -1 ) {
396507 throw new IllegalStateException ("No vector read yet, call readQuantizedVector first" );
397508 }
398509 return new OptimizedScalarQuantizer .QuantizationResult (corrections [0 ], corrections [1 ], corrections [2 ], bitSum );
399510 }
400511
512+ byte [] getVector (int ord , boolean isOverspill ) throws IOException {
513+ readQuantizedVector (ord , isOverspill );
514+ return binaryScratch ;
515+ }
516+
401517 public void readQuantizedVector (int ord , boolean isOverspill ) throws IOException {
402- if (ord == currOrd && isOverspill == this .isOverspill ) {
403- return ; // no need to read again
404- }
405- long offset = (long ) ord * (vectorByteSize * 2 ) + (isOverspill ? vectorByteSize : 0 );
518+ long offset = (long ) ord * (vectorByteSize * 2L ) + (isOverspill ? vectorByteSize : 0 );
406519 quantizedVectorsInput .seek (offset );
407520 quantizedVectorsInput .readBytes (binaryScratch , 0 , binaryScratch .length );
408521 quantizedVectorsInput .readFloats (corrections , 0 , 3 );
409522 bitSum = quantizedVectorsInput .readShort ();
410- if (ord != currOrd ) {
411- currOrd = ord ;
412- }
413- this .isOverspill = isOverspill ;
414523 }
415524 }
416525}
0 commit comments