1414import org .apache .lucene .index .FloatVectorValues ;
1515import org .apache .lucene .index .MergeState ;
1616import org .apache .lucene .index .SegmentWriteState ;
17- import org .apache .lucene .internal .hppc .IntArrayList ;
1817import org .apache .lucene .store .IndexInput ;
1918import org .apache .lucene .store .IndexOutput ;
2019import org .apache .lucene .util .VectorUtil ;
2726import java .io .IOException ;
2827import java .nio .ByteBuffer ;
2928import java .nio .ByteOrder ;
29+ import java .util .Arrays ;
3030
3131import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .INDEX_BITS ;
3232import static org .elasticsearch .index .codec .vectors .BQVectorUtils .discretize ;
@@ -53,7 +53,7 @@ long[] buildAndWritePostingsLists(
5353 CentroidSupplier centroidSupplier ,
5454 FloatVectorValues floatVectorValues ,
5555 IndexOutput postingsOutput ,
56- IntArrayList [] assignmentsByCluster
56+ int [] [] assignmentsByCluster
5757 ) throws IOException {
5858 // write the posting lists
5959 final long [] offsets = new long [centroidSupplier .size ()];
@@ -65,16 +65,16 @@ long[] buildAndWritePostingsLists(
6565 float [] centroid = centroidSupplier .centroid (c );
6666 binarizedByteVectorValues .centroid = centroid ;
6767 // TODO: add back in sorting vectors by distance to centroid
68- IntArrayList cluster = assignmentsByCluster [c ];
68+ int [] cluster = assignmentsByCluster [c ];
6969 // TODO align???
7070 offsets [c ] = postingsOutput .getFilePointer ();
71- int size = cluster .size () ;
71+ int size = cluster .length ;
7272 postingsOutput .writeVInt (size );
7373 postingsOutput .writeInt (Float .floatToIntBits (VectorUtil .dotProduct (centroid , centroid )));
7474 // TODO we might want to consider putting the docIds in a separate file
7575 // to aid with only having to fetch vectors from slower storage when they are required
7676 // keeping them in the same file indicates we pull the entire file into cache
77- docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster . get ( j ) ), size , postingsOutput );
77+ docIdsWriter .writeDocIds (j -> floatVectorValues .ordToDoc (cluster [ j ] ), size , postingsOutput );
7878 writePostingList (cluster , postingsOutput , binarizedByteVectorValues );
7979 }
8080
@@ -85,23 +85,23 @@ long[] buildAndWritePostingsLists(
8585 return offsets ;
8686 }
8787
88- private static void printClusterQualityStatistics (IntArrayList [] clusters ) {
88+ private static void printClusterQualityStatistics (int [] [] clusters ) {
8989 float min = Float .MAX_VALUE ;
9090 float max = Float .MIN_VALUE ;
9191 float mean = 0 ;
9292 float m2 = 0 ;
9393 // iteratively compute the variance & mean
9494 int count = 0 ;
95- for (IntArrayList cluster : clusters ) {
95+ for (int [] cluster : clusters ) {
9696 count += 1 ;
9797 if (cluster == null ) {
9898 continue ;
9999 }
100- float delta = cluster .size () - mean ;
100+ float delta = cluster .length - mean ;
101101 mean += delta / count ;
102- m2 += delta * (cluster .size () - mean );
103- min = Math .min (min , cluster .size () );
104- max = Math .max (max , cluster .size () );
102+ m2 += delta * (cluster .length - mean );
103+ min = Math .min (min , cluster .length );
104+ max = Math .max (max , cluster .length );
105105 }
106106 float variance = m2 / (clusters .length - 1 );
107107 logger .debug (
@@ -115,16 +115,16 @@ private static void printClusterQualityStatistics(IntArrayList[] clusters) {
115115 );
116116 }
117117
118- private void writePostingList (IntArrayList cluster , IndexOutput postingsOutput , BinarizedFloatVectorValues binarizedByteVectorValues )
118+ private void writePostingList (int [] cluster , IndexOutput postingsOutput , BinarizedFloatVectorValues binarizedByteVectorValues )
119119 throws IOException {
120- int limit = cluster .size () - ES91OSQVectorsScorer .BULK_SIZE + 1 ;
120+ int limit = cluster .length - ES91OSQVectorsScorer .BULK_SIZE + 1 ;
121121 int cidx = 0 ;
122122 OptimizedScalarQuantizer .QuantizationResult [] corrections =
123123 new OptimizedScalarQuantizer .QuantizationResult [ES91OSQVectorsScorer .BULK_SIZE ];
124124 // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
125125 for (; cidx < limit ; cidx += ES91OSQVectorsScorer .BULK_SIZE ) {
126126 for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
127- int ord = cluster . get ( cidx + j ) ;
127+ int ord = cluster [ cidx + j ] ;
128128 byte [] binaryValue = binarizedByteVectorValues .vectorValue (ord );
129129 // write vector
130130 postingsOutput .writeBytes (binaryValue , 0 , binaryValue .length );
@@ -147,8 +147,8 @@ private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput,
147147 }
148148 }
149149 // write tail
150- for (; cidx < cluster .size () ; cidx ++) {
151- int ord = cluster . get ( cidx ) ;
150+ for (; cidx < cluster .length ; cidx ++) {
151+ int ord = cluster [ cidx ] ;
152152 // write vector
153153 byte [] binaryValue = binarizedByteVectorValues .vectorValue (ord );
154154 OptimizedScalarQuantizer .QuantizationResult correction = binarizedByteVectorValues .getCorrectiveTerms (ord );
@@ -261,23 +261,31 @@ CentroidAssignments calculateAndWriteCentroids(
261261 logger .debug ("final centroid count: {}" , centroids .length );
262262 }
263263
264- IntArrayList [] assignmentsByCluster = new IntArrayList [centroids .length ];
265- for (int c = 0 ; c < centroids .length ; c ++) {
266- IntArrayList cluster = new IntArrayList (vectorPerCluster );
267- for (int j = 0 ; j < assignments .length ; j ++) {
268- if (assignments [j ] == c ) {
269- cluster .add (j );
270- }
264+ int [] centroidVectorCount = new int [centroids .length ];
265+ for (int i = 0 ; i < assignments .length ; i ++) {
266+ centroidVectorCount [assignments [i ]]++;
267+ // if soar assignments are present, count them as well
268+ if (soarAssignments .length > i && soarAssignments [i ] != -1 ) {
269+ centroidVectorCount [soarAssignments [i ]]++;
271270 }
271+ }
272272
273- for (int j = 0 ; j < soarAssignments .length ; j ++) {
274- if (soarAssignments [j ] == c ) {
275- cluster .add (j );
273+ int [][] assignmentsByCluster = new int [centroids .length ][];
274+ for (int c = 0 ; c < centroids .length ; c ++) {
275+ assignmentsByCluster [c ] = new int [centroidVectorCount [c ]];
276+ }
277+ Arrays .fill (centroidVectorCount , 0 );
278+
279+ for (int i = 0 ; i < assignments .length ; i ++) {
280+ int c = assignments [i ];
281+ assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
282+ // if soar assignments are present, add them to the cluster as well
283+ if (soarAssignments .length > i ) {
284+ int s = soarAssignments [i ];
285+ if (s != -1 ) {
286+ assignmentsByCluster [s ][centroidVectorCount [s ]++] = i ;
276287 }
277288 }
278-
279- cluster .trimToSize ();
280- assignmentsByCluster [c ] = cluster ;
281289 }
282290
283291 if (cacheCentroids ) {
0 commit comments