1414import org .apache .lucene .index .FloatVectorValues ;
1515import org .apache .lucene .index .MergeState ;
1616import org .apache .lucene .index .SegmentWriteState ;
17+ import org .apache .lucene .store .IOContext ;
1718import org .apache .lucene .store .IndexInput ;
1819import org .apache .lucene .store .IndexOutput ;
1920import org .apache .lucene .util .VectorUtil ;
21+ import org .apache .lucene .util .hnsw .IntToIntFunction ;
2022import org .elasticsearch .index .codec .vectors .cluster .HierarchicalKMeans ;
2123import org .elasticsearch .index .codec .vectors .cluster .KMeansResult ;
2224import org .elasticsearch .logging .LogManager ;
@@ -49,32 +51,58 @@ long[] buildAndWritePostingsLists(
4951 CentroidSupplier centroidSupplier ,
5052 FloatVectorValues floatVectorValues ,
5153 IndexOutput postingsOutput ,
52- int [][] assignmentsByCluster
54+ int [] assignments ,
55+ int [] overspillAssignments
5356 ) throws IOException {
57+ int [] centroidVectorCount = new int [centroidSupplier .size ()];
58+ for (int i = 0 ; i < assignments .length ; i ++) {
59+ centroidVectorCount [assignments [i ]]++;
60+ // if soar assignments are present, count them as well
61+ if (overspillAssignments .length > i && overspillAssignments [i ] != -1 ) {
62+ centroidVectorCount [overspillAssignments [i ]]++;
63+ }
64+ }
65+
66+ int [][] assignmentsByCluster = new int [centroidSupplier .size ()][];
67+ for (int c = 0 ; c < centroidSupplier .size (); c ++) {
68+ assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
69+ }
70+ Arrays .fill (centroidVectorCount , 0 );
71+
72+ for (int i = 0 ; i < assignments .length ; i ++) {
73+ int c = assignments [i ];
74+ assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
75+ // if soar assignments are present, add them to the cluster as well
76+ if (overspillAssignments .length > i ) {
77+ int s = overspillAssignments [i ];
78+ if (s != -1 ) {
79+ assignmentsByCluster [s ][centroidVectorCount [s ]++] = i ;
80+ }
81+ }
82+ }
5483 // write the posting lists
5584 final long [] offsets = new long [centroidSupplier .size ()];
56- OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
5785 DocIdsWriter docIdsWriter = new DocIdsWriter ();
58- DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (
59- ES91OSQVectorsScorer .BULK_SIZE ,
60- quantizer ,
86+ DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (ES91OSQVectorsScorer .BULK_SIZE , postingsOutput );
87+ OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors (
6188 floatVectorValues ,
62- postingsOutput
89+ fieldInfo .getVectorDimension (),
90+ new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ())
6391 );
6492 for (int c = 0 ; c < centroidSupplier .size (); c ++) {
6593 float [] centroid = centroidSupplier .centroid (c );
66- // TODO: add back in sorting vectors by distance to centroid
6794 int [] cluster = assignmentsByCluster [c ];
6895 // TODO align???
6996 offsets [c ] = postingsOutput .getFilePointer ();
7097 int size = cluster .length ;
7198 postingsOutput .writeVInt (size );
7299 postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
100+ onHeapQuantizedVectors .reset (centroid , size , ord -> cluster [ord ]);
73101 // TODO we might want to consider putting the docIds in a separate file
74102 // to aid with only having to fetch vectors from slower storage when they are required
75103 // keeping them in the same file indicates we pull the entire file into cache
76104 docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster [j ]), size , postingsOutput );
77- bulkWriter .writeOrds ( j -> cluster [ j ], cluster . length , centroid );
105+ bulkWriter .writeVectors ( onHeapQuantizedVectors );
78106 }
79107
80108 if (logger .isDebugEnabled ()) {
@@ -84,6 +112,124 @@ long[] buildAndWritePostingsLists(
84112 return offsets ;
85113 }
86114
115+ @ Override
116+ long [] buildAndWritePostingsLists (
117+ FieldInfo fieldInfo ,
118+ CentroidSupplier centroidSupplier ,
119+ FloatVectorValues floatVectorValues ,
120+ IndexOutput postingsOutput ,
121+ MergeState mergeState ,
122+ int [] assignments ,
123+ int [] overspillAssignments
124+ ) throws IOException {
125+ // first, quantize all the vectors into a temporary file
126+ String quantizedVectorsTempName = null ;
127+ IndexOutput quantizedVectorsTemp = null ;
128+ boolean success = false ;
129+ try {
130+ quantizedVectorsTemp = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "qvec_" , IOContext .DEFAULT );
131+ quantizedVectorsTempName = quantizedVectorsTemp .getName ();
132+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
133+ int [] quantized = new int [fieldInfo .getVectorDimension ()];
134+ byte [] binary = new byte [BQVectorUtils .discretize (fieldInfo .getVectorDimension (), 64 ) / 8 ];
135+ float [] overspillScratch = new float [fieldInfo .getVectorDimension ()];
136+ for (int i = 0 ; i < assignments .length ; i ++) {
137+ int c = assignments [i ];
138+ float [] centroid = centroidSupplier .centroid (c );
139+ float [] vector = floatVectorValues .vectorValue (i );
140+ boolean overspill = overspillAssignments .length > i && overspillAssignments [i ] != -1 ;
141+ // if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector
142+ // so, make a copy of the vector to avoid mutating it
143+ if (overspill ) {
144+ System .arraycopy (vector , 0 , overspillScratch , 0 , fieldInfo .getVectorDimension ());
145+ }
146+
147+ OptimizedScalarQuantizer .QuantizationResult result = quantizer .scalarQuantize (vector , quantized , (byte ) 1 , centroid );
148+ BQVectorUtils .packAsBinary (quantized , binary );
149+ writeQuantizedValue (quantizedVectorsTemp , binary , result );
150+ if (overspill ) {
151+ int s = overspillAssignments [i ];
152+ // write the overspill vector as well
153+ result = quantizer .scalarQuantize (overspillScratch , quantized , (byte ) 1 , centroidSupplier .centroid (s ));
154+ BQVectorUtils .packAsBinary (quantized , binary );
155+ writeQuantizedValue (quantizedVectorsTemp , binary , result );
156+ } else {
157+ // write a zero vector for the overspill
158+ Arrays .fill (binary , (byte ) 0 );
159+ OptimizedScalarQuantizer .QuantizationResult zeroResult = new OptimizedScalarQuantizer .QuantizationResult (0f , 0f , 0f , 0 );
160+ writeQuantizedValue (quantizedVectorsTemp , binary , zeroResult );
161+ }
162+ }
163+ // close the temporary file so we can read it later
164+ quantizedVectorsTemp .close ();
165+ success = true ;
166+ } finally {
167+ if (success == false && quantizedVectorsTemp != null ) {
168+ mergeState .segmentInfo .dir .deleteFile (quantizedVectorsTemp .getName ());
169+ }
170+ }
171+ int [] centroidVectorCount = new int [centroidSupplier .size ()];
172+ for (int i = 0 ; i < assignments .length ; i ++) {
173+ centroidVectorCount [assignments [i ]]++;
174+ // if soar assignments are present, count them as well
175+ if (overspillAssignments .length > i && overspillAssignments [i ] != -1 ) {
176+ centroidVectorCount [overspillAssignments [i ]]++;
177+ }
178+ }
179+
180+ int [][] assignmentsByCluster = new int [centroidSupplier .size ()][];
181+ boolean [][] isOverspillByCluster = new boolean [centroidSupplier .size ()][];
182+ for (int c = 0 ; c < centroidSupplier .size (); c ++) {
183+ assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
184+ isOverspillByCluster [c ] = new boolean [centroidVectorCount [c ]];
185+ }
186+ Arrays .fill (centroidVectorCount , 0 );
187+
188+ for (int i = 0 ; i < assignments .length ; i ++) {
189+ int c = assignments [i ];
190+ assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
191+ // if soar assignments are present, add them to the cluster as well
192+ if (overspillAssignments .length > i ) {
193+ int s = overspillAssignments [i ];
194+ if (s != -1 ) {
195+ assignmentsByCluster [s ][centroidVectorCount [s ]] = i ;
196+ isOverspillByCluster [s ][centroidVectorCount [s ]++] = true ;
197+ }
198+ }
199+ }
200+ // now we can read the quantized vectors from the temporary file
201+ try (IndexInput quantizedVectorsInput = mergeState .segmentInfo .dir .openInput (quantizedVectorsTempName , IOContext .DEFAULT )) {
202+ final long [] offsets = new long [centroidSupplier .size ()];
203+ OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors (
204+ quantizedVectorsInput ,
205+ fieldInfo .getVectorDimension ()
206+ );
207+ DocIdsWriter docIdsWriter = new DocIdsWriter ();
208+ DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (ES91OSQVectorsScorer .BULK_SIZE , postingsOutput );
209+ for (int c = 0 ; c < centroidSupplier .size (); c ++) {
210+ float [] centroid = centroidSupplier .centroid (c );
211+ int [] cluster = assignmentsByCluster [c ];
212+ boolean [] isOverspill = isOverspillByCluster [c ];
213+ // TODO align???
214+ offsets [c ] = postingsOutput .getFilePointer ();
215+ int size = cluster .length ;
216+ postingsOutput .writeVInt (size );
217+ postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
218+ offHeapQuantizedVectors .reset (size , ord -> isOverspill [ord ], ord -> cluster [ord ]);
219+ // TODO we might want to consider putting the docIds in a separate file
220+ // to aid with only having to fetch vectors from slower storage when they are required
221+ // keeping them in the same file indicates we pull the entire file into cache
222+ docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster [j ]), size , postingsOutput );
223+ bulkWriter .writeVectors (offHeapQuantizedVectors );
224+ }
225+
226+ if (logger .isDebugEnabled ()) {
227+ printClusterQualityStatistics (assignmentsByCluster );
228+ }
229+ return offsets ;
230+ }
231+ }
232+
87233 private static void printClusterQualityStatistics (int [][] clusters ) {
88234 float min = Float .MAX_VALUE ;
89235 float max = Float .MIN_VALUE ;
@@ -210,33 +356,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
210356 float [][] centroids = kMeansResult .centroids ();
211357 int [] assignments = kMeansResult .assignments ();
212358 int [] soarAssignments = kMeansResult .soarAssignments ();
213- int [] centroidVectorCount = new int [centroids .length ];
214- for (int i = 0 ; i < assignments .length ; i ++) {
215- centroidVectorCount [assignments [i ]]++;
216- // if soar assignments are present, count them as well
217- if (soarAssignments .length > i && soarAssignments [i ] != -1 ) {
218- centroidVectorCount [soarAssignments [i ]]++;
219- }
220- }
221-
222- int [][] assignmentsByCluster = new int [centroids .length ][];
223- for (int c = 0 ; c < centroids .length ; c ++) {
224- assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
225- }
226- Arrays .fill (centroidVectorCount , 0 );
227-
228- for (int i = 0 ; i < assignments .length ; i ++) {
229- int c = assignments [i ];
230- assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
231- // if soar assignments are present, add them to the cluster as well
232- if (soarAssignments .length > i ) {
233- int s = soarAssignments [i ];
234- if (s != -1 ) {
235- assignmentsByCluster [s ][centroidVectorCount [s ]++] = i ;
236- }
237- }
238- }
239- return new CentroidAssignments (centroids , assignmentsByCluster );
359+ return new CentroidAssignments (centroids , assignments , soarAssignments );
240360 }
241361
242362 static void writeQuantizedValue (IndexOutput indexOutput , byte [] binaryValue , OptimizedScalarQuantizer .QuantizationResult corrections )
@@ -281,4 +401,132 @@ public float[] centroid(int centroidOrdinal) throws IOException {
281401 return scratch ;
282402 }
283403 }
404+
405+ interface QuantizedVectorValues {
406+ int count ();
407+
408+ byte [] next () throws IOException ;
409+
410+ OptimizedScalarQuantizer .QuantizationResult getCorrections () throws IOException ;
411+ }
412+
/**
 * Maps an {@code int} ordinal to a {@code boolean}; the primitive counterpart of a
 * {@code Function<Integer, Boolean>} without boxing.
 *
 * <p>Intended to be implemented with a lambda or method reference.
 */
@FunctionalInterface
interface IntToBooleanFunction {

    /**
     * @param ord the ordinal to evaluate
     * @return the boolean associated with {@code ord}
     */
    boolean apply(int ord);
}
416+
417+ static class OnHeapQuantizedVectors implements QuantizedVectorValues {
418+ private final FloatVectorValues vectorValues ;
419+ private final OptimizedScalarQuantizer quantizer ;
420+ private final byte [] quantizedVector ;
421+ private final int [] quantizedVectorScratch ;
422+ private OptimizedScalarQuantizer .QuantizationResult corrections ;
423+ private float [] currentCentroid ;
424+ private IntToIntFunction ordTransformer = null ;
425+ private int currOrd = -1 ;
426+ private int count ;
427+
428+ OnHeapQuantizedVectors (FloatVectorValues vectorValues , int dimension , OptimizedScalarQuantizer quantizer ) {
429+ this .vectorValues = vectorValues ;
430+ this .quantizer = quantizer ;
431+ this .quantizedVector = new byte [BQVectorUtils .discretize (dimension , 64 ) / 8 ];
432+ this .quantizedVectorScratch = new int [dimension ];
433+ this .corrections = null ;
434+ }
435+
436+ private void reset (float [] centroid , int count , IntToIntFunction ordTransformer ) {
437+ this .currentCentroid = centroid ;
438+ this .ordTransformer = ordTransformer ;
439+ this .currOrd = -1 ;
440+ this .count = count ;
441+ }
442+
443+ @ Override
444+ public int count () {
445+ return count ;
446+ }
447+
448+ @ Override
449+ public byte [] next () throws IOException {
450+ if (currOrd >= count () - 1 ) {
451+ throw new IllegalStateException ("No more vectors to read, current ord: " + currOrd + ", count: " + count ());
452+ }
453+ currOrd ++;
454+ int ord = ordTransformer .apply (currOrd );
455+ float [] vector = vectorValues .vectorValue (ord );
456+ corrections = quantizer .scalarQuantize (vector , quantizedVectorScratch , (byte ) 1 , currentCentroid );
457+ BQVectorUtils .packAsBinary (quantizedVectorScratch , quantizedVector );
458+ return quantizedVector ;
459+ }
460+
461+ @ Override
462+ public OptimizedScalarQuantizer .QuantizationResult getCorrections () throws IOException {
463+ if (currOrd == -1 ) {
464+ throw new IllegalStateException ("No vector read yet, call next first" );
465+ }
466+ return corrections ;
467+ }
468+ }
469+
470+ static class OffHeapQuantizedVectors implements QuantizedVectorValues {
471+ private final IndexInput quantizedVectorsInput ;
472+ private final byte [] binaryScratch ;
473+ private final float [] corrections = new float [3 ];
474+
475+ private final int vectorByteSize ;
476+ private short bitSum ;
477+ private int currOrd = -1 ;
478+ private int count ;
479+ private IntToBooleanFunction isOverspill = null ;
480+ private IntToIntFunction ordTransformer = null ;
481+
482+ OffHeapQuantizedVectors (IndexInput quantizedVectorsInput , int dimension ) {
483+ this .quantizedVectorsInput = quantizedVectorsInput ;
484+ this .binaryScratch = new byte [BQVectorUtils .discretize (dimension , 64 ) / 8 ];
485+ this .vectorByteSize = (binaryScratch .length + 3 * Float .BYTES + Short .BYTES );
486+ }
487+
488+ private void reset (int count , IntToBooleanFunction isOverspill , IntToIntFunction ordTransformer ) {
489+ this .count = count ;
490+ this .isOverspill = isOverspill ;
491+ this .ordTransformer = ordTransformer ;
492+ this .currOrd = -1 ;
493+ }
494+
495+ @ Override
496+ public int count () {
497+ return count ;
498+ }
499+
500+ @ Override
501+ public byte [] next () throws IOException {
502+ if (currOrd >= count - 1 ) {
503+ throw new IllegalStateException ("No more vectors to read, current ord: " + currOrd + ", count: " + count );
504+ }
505+ currOrd ++;
506+ int ord = ordTransformer .apply (currOrd );
507+ boolean isOverspill = this .isOverspill .apply (currOrd );
508+ return getVector (ord , isOverspill );
509+ }
510+
511+ @ Override
512+ public OptimizedScalarQuantizer .QuantizationResult getCorrections () throws IOException {
513+ if (currOrd == -1 ) {
514+ throw new IllegalStateException ("No vector read yet, call readQuantizedVector first" );
515+ }
516+ return new OptimizedScalarQuantizer .QuantizationResult (corrections [0 ], corrections [1 ], corrections [2 ], bitSum );
517+ }
518+
519+ byte [] getVector (int ord , boolean isOverspill ) throws IOException {
520+ readQuantizedVector (ord , isOverspill );
521+ return binaryScratch ;
522+ }
523+
524+ public void readQuantizedVector (int ord , boolean isOverspill ) throws IOException {
525+ long offset = (long ) ord * (vectorByteSize * 2L ) + (isOverspill ? vectorByteSize : 0 );
526+ quantizedVectorsInput .seek (offset );
527+ quantizedVectorsInput .readBytes (binaryScratch , 0 , binaryScratch .length );
528+ quantizedVectorsInput .readFloats (corrections , 0 , 3 );
529+ bitSum = quantizedVectorsInput .readShort ();
530+ }
531+ }
284532}
0 commit comments