2828import java .nio .ByteOrder ;
2929import java .util .Arrays ;
3030
31- import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .INDEX_BITS ;
32- import static org .elasticsearch .index .codec .vectors .BQVectorUtils .discretize ;
33- import static org .elasticsearch .index .codec .vectors .BQVectorUtils .packAsBinary ;
34-
3531/**
3632 * Default implementation of {@link IVFVectorsWriter}. It uses the {@link HierarchicalKMeans} algorithm to
3733 * partition the vector space, and then stores the centroids and posting list in a sequential
@@ -58,12 +54,15 @@ long[] buildAndWritePostingsLists(
5854 // write the posting lists
5955 final long [] offsets = new long [centroidSupplier .size ()];
6056 OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
61- BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues (floatVectorValues , quantizer );
6257 DocIdsWriter docIdsWriter = new DocIdsWriter ();
63-
58+ DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (
59+ ES91OSQVectorsScorer .BULK_SIZE ,
60+ quantizer ,
61+ floatVectorValues ,
62+ postingsOutput
63+ );
6464 for (int c = 0 ; c < centroidSupplier .size (); c ++) {
6565 float [] centroid = centroidSupplier .centroid (c );
66- binarizedByteVectorValues .centroid = centroid ;
6766 // TODO: add back in sorting vectors by distance to centroid
6867 int [] cluster = assignmentsByCluster [c ];
6968 // TODO align???
@@ -75,7 +74,7 @@ long[] buildAndWritePostingsLists(
7574 // to aid with only having to fetch vectors from slower storage when they are required
7675 // keeping them in the same file indicates we pull the entire file into cache
7776 docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster [j ]), size , postingsOutput );
78- writePostingList ( cluster , postingsOutput , binarizedByteVectorValues );
77+ bulkWriter . writeOrds ( j -> cluster [ j ], cluster . length , centroid );
7978 }
8079
8180 if (logger .isDebugEnabled ()) {
@@ -115,54 +114,6 @@ private static void printClusterQualityStatistics(int[][] clusters) {
115114 );
116115 }
117116
118- private void writePostingList (int [] cluster , IndexOutput postingsOutput , BinarizedFloatVectorValues binarizedByteVectorValues )
119- throws IOException {
120- int limit = cluster .length - ES91OSQVectorsScorer .BULK_SIZE + 1 ;
121- int cidx = 0 ;
122- OptimizedScalarQuantizer .QuantizationResult [] corrections =
123- new OptimizedScalarQuantizer .QuantizationResult [ES91OSQVectorsScorer .BULK_SIZE ];
124- // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
125- for (; cidx < limit ; cidx += ES91OSQVectorsScorer .BULK_SIZE ) {
126- for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
127- int ord = cluster [cidx + j ];
128- byte [] binaryValue = binarizedByteVectorValues .vectorValue (ord );
129- // write vector
130- postingsOutput .writeBytes (binaryValue , 0 , binaryValue .length );
131- corrections [j ] = binarizedByteVectorValues .getCorrectiveTerms (ord );
132- }
133- // write corrections
134- for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
135- postingsOutput .writeInt (Float .floatToIntBits (corrections [j ].lowerInterval ()));
136- }
137- for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
138- postingsOutput .writeInt (Float .floatToIntBits (corrections [j ].upperInterval ()));
139- }
140- for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
141- int targetComponentSum = corrections [j ].quantizedComponentSum ();
142- assert targetComponentSum >= 0 && targetComponentSum <= 0xffff ;
143- postingsOutput .writeShort ((short ) targetComponentSum );
144- }
145- for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
146- postingsOutput .writeInt (Float .floatToIntBits (corrections [j ].additionalCorrection ()));
147- }
148- }
149- // write tail
150- for (; cidx < cluster .length ; cidx ++) {
151- int ord = cluster [cidx ];
152- // write vector
153- byte [] binaryValue = binarizedByteVectorValues .vectorValue (ord );
154- OptimizedScalarQuantizer .QuantizationResult correction = binarizedByteVectorValues .getCorrectiveTerms (ord );
155- writeQuantizedValue (postingsOutput , binaryValue , correction );
156- binarizedByteVectorValues .getCorrectiveTerms (ord );
157- postingsOutput .writeBytes (binaryValue , 0 , binaryValue .length );
158- postingsOutput .writeInt (Float .floatToIntBits (correction .lowerInterval ()));
159- postingsOutput .writeInt (Float .floatToIntBits (correction .upperInterval ()));
160- postingsOutput .writeInt (Float .floatToIntBits (correction .additionalCorrection ()));
161- assert correction .quantizedComponentSum () >= 0 && correction .quantizedComponentSum () <= 0xffff ;
162- postingsOutput .writeShort ((short ) correction .quantizedComponentSum ());
163- }
164- }
165-
166117 @ Override
167118 CentroidSupplier createCentroidSupplier (IndexInput centroidsInput , int numCentroids , FieldInfo fieldInfo , float [] globalCentroid ) {
168119 return new OffHeapCentroidSupplier (centroidsInput , numCentroids , fieldInfo );
@@ -295,47 +246,6 @@ CentroidAssignments calculateAndWriteCentroids(
295246 }
296247 }
297248
298- // TODO unify with OSQ format
299- static class BinarizedFloatVectorValues {
300- private OptimizedScalarQuantizer .QuantizationResult corrections ;
301- private final byte [] binarized ;
302- private final byte [] initQuantized ;
303- private float [] centroid ;
304- private final FloatVectorValues values ;
305- private final OptimizedScalarQuantizer quantizer ;
306-
307- private int lastOrd = -1 ;
308-
309- BinarizedFloatVectorValues (FloatVectorValues delegate , OptimizedScalarQuantizer quantizer ) {
310- this .values = delegate ;
311- this .quantizer = quantizer ;
312- this .binarized = new byte [discretize (delegate .dimension (), 64 ) / 8 ];
313- this .initQuantized = new byte [delegate .dimension ()];
314- }
315-
316- public OptimizedScalarQuantizer .QuantizationResult getCorrectiveTerms (int ord ) {
317- if (ord != lastOrd ) {
318- throw new IllegalStateException (
319- "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
320- );
321- }
322- return corrections ;
323- }
324-
325- public byte [] vectorValue (int ord ) throws IOException {
326- if (ord != lastOrd ) {
327- binarize (ord );
328- lastOrd = ord ;
329- }
330- return binarized ;
331- }
332-
333- private void binarize (int ord ) throws IOException {
334- corrections = quantizer .scalarQuantize (values .vectorValue (ord ), initQuantized , INDEX_BITS , centroid );
335- packAsBinary (initQuantized , binarized );
336- }
337- }
338-
339249 static void writeQuantizedValue (IndexOutput indexOutput , byte [] binaryValue , OptimizedScalarQuantizer .QuantizationResult corrections )
340250 throws IOException {
341251 indexOutput .writeBytes (binaryValue , binaryValue .length );
0 commit comments