1717import org .apache .lucene .store .IOContext ;
1818import org .apache .lucene .store .IndexInput ;
1919import org .apache .lucene .store .IndexOutput ;
20- import org .apache .lucene .util .IntroSorter ;
2120import org .apache .lucene .util .LongValues ;
2221import org .apache .lucene .util .VectorUtil ;
2322import org .apache .lucene .util .hnsw .IntToIntFunction ;
@@ -49,28 +48,57 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec
4948 this .vectorPerCluster = vectorPerCluster ;
5049 }
5150
52- @ Override
53- LongValues buildAndWritePostingsLists (
54- FieldInfo fieldInfo ,
55- CentroidSupplier centroidSupplier ,
56- FloatVectorValues floatVectorValues ,
57- IndexOutput postingsOutput ,
51+ private static void deltaEncode (int [] vals , int size , int [] deltas ) {
52+ if (size == 0 ) {
53+ return ;
54+ }
55+ deltas [0 ] = vals [0 ];
56+ for (int i = 1 ; i < size ; i ++) {
57+ assert vals [i ] >= vals [i - 1 ] : "vals are not sorted: " + vals [i ] + " < " + vals [i - 1 ];
58+ deltas [i ] = vals [i ] - vals [i - 1 ];
59+ }
60+ }
61+
62+ private static void translateOrdsToDocs (
63+ int [] ords ,
64+ int size ,
65+ int [] spillOrds ,
66+ int spillSize ,
67+ int [] docIds ,
68+ int [] spillDocIds ,
69+ IntToIntFunction ordToDoc
70+ ) {
71+ int ordIdx = 0 , spillOrdIdx = 0 ;
72+ while (ordIdx < size || spillOrdIdx < spillSize ) {
73+ int nextOrd = (ordIdx < size ) ? ords [ordIdx ] : Integer .MAX_VALUE ;
74+ int nextSpillOrd = (spillOrdIdx < spillSize ) ? spillOrds [spillOrdIdx ] : Integer .MAX_VALUE ;
75+ if (nextOrd < nextSpillOrd ) {
76+ docIds [ordIdx ] = ordToDoc .apply (nextOrd );
77+ ordIdx ++;
78+ } else {
79+ spillDocIds [spillOrdIdx ] = ordToDoc .apply (nextSpillOrd );
80+ spillOrdIdx ++;
81+ }
82+ }
83+ }
84+
85+ private static void pivotAssignments (
86+ int centroidCount ,
5887 int [] assignments ,
59- int [] overspillAssignments
60- ) throws IOException {
61- int [] centroidVectorCount = new int [centroidSupplier .size ()];
62- int [] overspillVectorCount = new int [centroidSupplier .size ()];
88+ int [] overspillAssignments ,
89+ int [][] assignmentsByCluster ,
90+ int [][] overspillAssignmentsByCluster
91+ ) {
92+ int [] centroidVectorCount = new int [centroidCount ];
93+ int [] overspillVectorCount = new int [centroidCount ];
6394 for (int i = 0 ; i < assignments .length ; i ++) {
6495 centroidVectorCount [assignments [i ]]++;
6596 // if soar assignments are present, count them as well
6697 if (overspillAssignments .length > i && overspillAssignments [i ] != -1 ) {
6798 overspillVectorCount [overspillAssignments [i ]]++;
6899 }
69100 }
70-
71- int [][] assignmentsByCluster = new int [centroidSupplier .size ()][];
72- int [][] overspillAssignmentsByCluster = new int [centroidSupplier .size ()][];
73- for (int c = 0 ; c < centroidSupplier .size (); c ++) {
101+ for (int c = 0 ; c < centroidCount ; c ++) {
74102 assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
75103 overspillAssignmentsByCluster [c ] = new int [overspillVectorCount [c ]];
76104 }
@@ -88,14 +116,35 @@ LongValues buildAndWritePostingsLists(
88116 }
89117 }
90118 }
119+ }
120+
121+ @ Override
122+ LongValues buildAndWritePostingsLists (
123+ FieldInfo fieldInfo ,
124+ CentroidSupplier centroidSupplier ,
125+ FloatVectorValues floatVectorValues ,
126+ IndexOutput postingsOutput ,
127+ int [] assignments ,
128+ int [] overspillAssignments
129+ ) throws IOException {
130+
91131 // write the posting lists
92132 final PackedLongValues .Builder offsets = PackedLongValues .monotonicBuilder (PackedInts .COMPACT );
93133 DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (ES91OSQVectorsScorer .BULK_SIZE , postingsOutput );
134+ // pivot the assignments into clusters
135+ int [][] assignmentsByCluster = new int [centroidSupplier .size ()][];
136+ int [][] overspillAssignmentsByCluster = new int [centroidSupplier .size ()][];
137+ pivotAssignments (centroidSupplier .size (), assignments , overspillAssignments , assignmentsByCluster , overspillAssignmentsByCluster );
94138
95139 int [] docIds = null ;
96140 int [] docDeltas = null ;
97141 int [] spillDocIds = null ;
98142 int [] spillDeltas = null ;
143+ final OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors (
144+ floatVectorValues ,
145+ fieldInfo .getVectorDimension (),
146+ new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ())
147+ );
99148 final ByteBuffer buffer = ByteBuffer .allocate (fieldInfo .getVectorDimension () * Float .BYTES ).order (ByteOrder .LITTLE_ENDIAN );
100149 for (int c = 0 ; c < centroidSupplier .size (); c ++) {
101150 float [] centroid = centroidSupplier .centroid (c );
@@ -115,42 +164,13 @@ LongValues buildAndWritePostingsLists(
115164 spillDocIds = new int [spillSize ];
116165 spillDeltas = new int [spillSize ];
117166 }
118- for (int j = 0 ; j < size ; j ++) {
119- docIds [j ] = floatVectorValues .ordToDoc (cluster [j ]);
120- }
121- for (int j = 0 ; j < spillSize ; j ++) {
122- spillDocIds [j ] = floatVectorValues .ordToDoc (overspillCluster [j ]);
123- }
124- final int [] finalDocs = docIds ;
125- final int [] finalSpillDocs = spillDocIds ;
167+ translateOrdsToDocs (cluster , size , overspillCluster , spillSize , docIds , spillDocIds , floatVectorValues ::ordToDoc );
126168 // encode doc deltas
127169 if (size > 0 ) {
128- docDeltas [0 ] = finalDocs [0 ];
129- for (int j = size - 1 ; j > 0 ; j --) {
130- if (finalDocs [j ] < finalDocs [j - 1 ]) {
131- throw new IllegalStateException (
132- "docIds are not sorted: "
133- + finalDocs [j ]
134- + " < "
135- + finalDocs [j - 1 ]
136- );
137- }
138- docDeltas [j ] = finalDocs [j ] - finalDocs [j - 1 ];
139- }
170+ deltaEncode (docIds , size , docDeltas );
140171 }
141172 if (spillSize > 0 ) {
142- spillDeltas [0 ] = finalSpillDocs [0 ];
143- for (int j = spillSize - 1 ; j > 0 ; j --) {
144- if (finalSpillDocs [j ] < finalSpillDocs [j - 1 ]) {
145- throw new IllegalStateException (
146- "Overspill docIds are not sorted: "
147- + finalSpillDocs [j ]
148- + " < "
149- + finalSpillDocs [j - 1 ]
150- );
151- }
152- spillDeltas [j ] = finalSpillDocs [j ] - finalSpillDocs [j - 1 ];
153- }
173+ deltaEncode (spillDocIds , spillSize , spillDeltas );
154174 }
155175 postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
156176 postingsOutput .writeInt (size );
@@ -160,25 +180,16 @@ LongValues buildAndWritePostingsLists(
160180 // keeping them in the same file indicates we pull the entire file into cache
161181 postingsOutput .writeGroupVInts (docDeltas , size );
162182 postingsOutput .writeGroupVInts (spillDeltas , spillSize );
163- OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors (
164- floatVectorValues ,
165- fieldInfo .getVectorDimension (),
166- new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ())
167- );
168183 onHeapQuantizedVectors .reset (centroid , size , j -> cluster [j ]);
169184 bulkWriter .writeVectors (onHeapQuantizedVectors );
170185 // write overspill vectors
171- onHeapQuantizedVectors = new OnHeapQuantizedVectors (
172- floatVectorValues ,
173- fieldInfo .getVectorDimension (),
174- new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ())
175- );
176186 onHeapQuantizedVectors .reset (centroid , spillSize , j -> overspillCluster [j ]);
177187 bulkWriter .writeVectors (onHeapQuantizedVectors );
178188 }
179189
180190 if (logger .isDebugEnabled ()) {
181191 printClusterQualityStatistics (assignmentsByCluster );
192+ printClusterQualityStatistics (overspillAssignmentsByCluster );
182193 }
183194
184195 return offsets .build ();
@@ -240,40 +251,22 @@ LongValues buildAndWritePostingsLists(
240251 mergeState .segmentInfo .dir .deleteFile (quantizedVectorsTemp .getName ());
241252 }
242253 }
243- int [] centroidVectorCount = new int [centroidSupplier .size ()];
244- int [] overspillVectorCount = new int [centroidSupplier .size ()];
245- for (int i = 0 ; i < assignments .length ; i ++) {
246- centroidVectorCount [assignments [i ]]++;
247- // if soar assignments are present, count them as well
248- if (overspillAssignments .length > i && overspillAssignments [i ] != -1 ) {
249- overspillVectorCount [overspillAssignments [i ]]++;
250- }
251- }
252-
253254 int [][] assignmentsByCluster = new int [centroidSupplier .size ()][];
254255 int [][] overspillAssignmentsByCluster = new int [centroidSupplier .size ()][];
255- for (int c = 0 ; c < centroidSupplier .size (); c ++) {
256- assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
257- overspillAssignmentsByCluster [c ] = new int [overspillVectorCount [c ]];
258- }
259- Arrays .fill (centroidVectorCount , 0 );
260- Arrays .fill (overspillVectorCount , 0 );
261- for (int i = 0 ; i < assignments .length ; i ++) {
262- int c = assignments [i ];
263- assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
264- // if soar assignments are present, add them to the cluster as well
265- if (overspillAssignments .length > i ) {
266- int s = overspillAssignments [i ];
267- if (s != -1 ) {
268- overspillAssignmentsByCluster [s ][overspillVectorCount [s ]++] = i ;
269- }
270- }
271- }
256+ // pivot the assignments into clusters
257+ pivotAssignments (centroidSupplier .size (), assignments , overspillAssignments , assignmentsByCluster , overspillAssignmentsByCluster );
272258 // now we can read the quantized vectors from the temporary file
273259 try (IndexInput quantizedVectorsInput = mergeState .segmentInfo .dir .openInput (quantizedVectorsTempName , IOContext .DEFAULT )) {
274260 final PackedLongValues .Builder offsets = PackedLongValues .monotonicBuilder (PackedInts .COMPACT );
275261
276- DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (ES91OSQVectorsScorer .BULK_SIZE , postingsOutput );
262+ final DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter .OneBitDiskBBQBulkWriter (
263+ ES91OSQVectorsScorer .BULK_SIZE ,
264+ postingsOutput
265+ );
266+ final OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors (
267+ quantizedVectorsInput ,
268+ fieldInfo .getVectorDimension ()
269+ );
277270 int [] docIds = null ;
278271 int [] docDeltas = null ;
279272 int [] spillDocIds = null ;
@@ -297,26 +290,14 @@ LongValues buildAndWritePostingsLists(
297290 spillDocIds = new int [spillSize ];
298291 spillDeltas = new int [spillSize ];
299292 }
300- for (int j = 0 ; j < size ; j ++) {
301- docIds [j ] = floatVectorValues .ordToDoc (cluster [j ]);
302- }
303- for (int j = 0 ; j < spillSize ; j ++) {
304- spillDocIds [j ] = floatVectorValues .ordToDoc (overspillCluster [j ]);
305- }
306- final int [] finalDocs = docIds ;
307- final int [] finalSpillDocs = spillDocIds ;
293+ // translate ordinals to docIds
294+ translateOrdsToDocs (cluster , size , overspillCluster , spillSize , docIds , spillDocIds , floatVectorValues ::ordToDoc );
308295 // encode doc deltas
309296 if (size > 0 ) {
310- docDeltas [0 ] = finalDocs [0 ];
311- for (int j = size - 1 ; j > 0 ; j --) {
312- docDeltas [j ] = finalDocs [j ] - finalDocs [j - 1 ];
313- }
297+ deltaEncode (docIds , size , docDeltas );
314298 }
315299 if (spillSize > 0 ) {
316- spillDeltas [0 ] = finalSpillDocs [0 ];
317- for (int j = spillSize - 1 ; j > 0 ; j --) {
318- spillDeltas [j ] = finalSpillDocs [j ] - finalSpillDocs [j - 1 ];
319- }
300+ deltaEncode (spillDocIds , spillSize , spillDeltas );
320301 }
321302 postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
322303 postingsOutput .writeInt (size );
@@ -327,22 +308,16 @@ LongValues buildAndWritePostingsLists(
327308 postingsOutput .writeGroupVInts (docDeltas , size );
328309 postingsOutput .writeGroupVInts (spillDeltas , spillSize );
329310 // write overspill vectors
330- OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors (
331- quantizedVectorsInput ,
332- fieldInfo .getVectorDimension ()
333- );
311+
334312 offHeapQuantizedVectors .reset (size , false , j -> cluster [j ]);
335313 bulkWriter .writeVectors (offHeapQuantizedVectors );
336- offHeapQuantizedVectors = new OffHeapQuantizedVectors (
337- quantizedVectorsInput ,
338- fieldInfo .getVectorDimension ()
339- );
340314 offHeapQuantizedVectors .reset (spillSize , true , j -> overspillCluster [j ]);
341315 bulkWriter .writeVectors (offHeapQuantizedVectors );
342316 }
343317
344318 if (logger .isDebugEnabled ()) {
345319 printClusterQualityStatistics (assignmentsByCluster );
320+ printClusterQualityStatistics (overspillAssignmentsByCluster );
346321 }
347322 return offsets .build ();
348323 }
@@ -506,39 +481,6 @@ public float[] centroid(int centroidOrdinal) throws IOException {
506481 }
507482 }
508483
509- static class IntSorter extends IntroSorter {
510- int pivot = -1 ;
511- private final int [] arr ;
512- private final IntToIntFunction func ;
513-
514- IntSorter (int [] arr , IntToIntFunction func ) {
515- this .arr = arr ;
516- this .func = func ;
517- }
518-
519- @ Override
520- protected void setPivot (int i ) {
521- pivot = func .apply (arr [i ]);
522- }
523-
524- @ Override
525- protected int comparePivot (int j ) {
526- return Integer .compare (pivot , func .apply (arr [j ]));
527- }
528-
529- @ Override
530- protected int compare (int a , int b ) {
531- return Integer .compare (func .apply (arr [a ]), func .apply (arr [b ]));
532- }
533-
534- @ Override
535- protected void swap (int i , int j ) {
536- final int tmp = arr [i ];
537- arr [i ] = arr [j ];
538- arr [j ] = tmp ;
539- }
540- }
541-
542484 interface QuantizedVectorValues {
543485 int count ();
544486
0 commit comments