2828import java .nio .ByteBuffer ;
2929import java .nio .ByteOrder ;
3030import java .util .ArrayList ;
31- import java .util .Arrays ;
3231import java .util .List ;
3332
3433import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .INDEX_BITS ;
3534import static org .elasticsearch .index .codec .vectors .BQVectorUtils .discretize ;
3635import static org .elasticsearch .index .codec .vectors .BQVectorUtils .packAsBinary ;
36+ import static org .elasticsearch .index .codec .vectors .IVFVectorsFormat .DEFAULT_VECTORS_PER_CLUSTER ;
3737
3838/**
3939 * Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
@@ -167,14 +167,23 @@ private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput,
167167 }
168168
169169 @ Override
170- CentroidSupplier createCentroidSupplier (IndexInput centroidsInput , int numParentCentroids ,
171- int numCentroids , FieldInfo fieldInfo , float [] globalCentroid ) {
170+ CentroidSupplier createCentroidSupplier (
171+ IndexInput centroidsInput ,
172+ int numParentCentroids ,
173+ int numCentroids ,
174+ FieldInfo fieldInfo ,
175+ float [] globalCentroid
176+ ) {
172177 return new OffHeapCentroidSupplier (centroidsInput , numParentCentroids , numCentroids , fieldInfo );
173178 }
174179
175- static void writeCentroidsAndPartitions (List <CentroidPartition > centroidPartitions , float [][] centroids ,
176- FieldInfo fieldInfo , float [] globalCentroid , IndexOutput centroidOutput )
177- throws IOException {
180+ static void writeCentroidsAndPartitions (
181+ List <CentroidPartition > centroidPartitions ,
182+ float [][] centroids ,
183+ FieldInfo fieldInfo ,
184+ float [] globalCentroid ,
185+ IndexOutput centroidOutput
186+ ) throws IOException {
178187 final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
179188 byte [] quantizedScratch = new byte [fieldInfo .getVectorDimension ()];
180189 float [] centroidScratch = new float [fieldInfo .getVectorDimension ()];
@@ -208,7 +217,7 @@ static void writeCentroidsAndPartitions(List<CentroidPartition> centroidPartitio
208217 writeQuantizedValue (centroidOutput , quantizedScratch , result );
209218 }
210219
211- //write the raw float vectors so we can quantize the query vector relative to the centroid on read
220+ // write the raw float vectors so we can quantize the query vector relative to the centroid on read
212221 final ByteBuffer buffer = ByteBuffer .allocate (fieldInfo .getVectorDimension () * Float .BYTES ).order (ByteOrder .LITTLE_ENDIAN );
213222 for (float [] centroid : centroids ) {
214223 buffer .asFloatBuffer ().put (centroid );
@@ -281,41 +290,44 @@ CentroidAssignments calculateAndWriteCentroids(
281290 // TODO: sort while constructing the hkmeans structure
282291 // we do this so we don't have to sort the assignments which is much more expensive
283292 int [] centroidOrds = new int [centroids .length ];
284- for (int i = 0 ; i < centroidOrds .length ; i ++) {
293+ for (int i = 0 ; i < centroidOrds .length ; i ++) {
285294 centroidOrds [i ] = i ;
286295 }
287296
288- // TODO: sort by global centroids as well
289- // TODO: have this take a function instead of just an int[] for sorting
290- AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , kMeansResult .parentLayer ());
291- sorter .sort (0 , centroids .length );
292-
293297 List <CentroidPartition > centroidPartitions = new ArrayList <>();
294- for (int i = 0 ; i < kMeansResult .parentLayer ().length ;) {
295- // for any layer that was not partitioned we treat it duplicatively as a parent and child
296- if (kMeansResult .parentLayer ()[i ] == -1 ) {
297- centroidPartitions .add (new CentroidPartition (centroids [i ], i , 1 ));
298- i ++;
299- } else {
300- int label = kMeansResult .parentLayer ()[i ];
301- int totalCentroids = 0 ;
302- float [] parentPartitionCentroid = new float [fieldInfo .getVectorDimension ()];
303- int j = i ;
304- for (; j < kMeansResult .parentLayer ().length ; j ++) {
305- if (kMeansResult .parentLayer ()[j ] != label ) {
306- break ;
298+
299+ if (centroids .length > DEFAULT_VECTORS_PER_CLUSTER ) {
300+ // TODO: sort by global centroids as well
301+ // TODO: have this take a function instead of just an int[] for sorting
302+ AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , kMeansResult .parentLayer ());
303+ sorter .sort (0 , centroids .length );
304+
305+ for (int i = 0 ; i < kMeansResult .parentLayer ().length ;) {
306+ // for any layer that was not partitioned we treat it duplicatively as a parent and child
307+ if (kMeansResult .parentLayer ()[i ] == -1 ) {
308+ centroidPartitions .add (new CentroidPartition (centroids [i ], i , 1 ));
309+ i ++;
310+ } else {
311+ int label = kMeansResult .parentLayer ()[i ];
312+ int totalCentroids = 0 ;
313+ float [] parentPartitionCentroid = new float [fieldInfo .getVectorDimension ()];
314+ int j = i ;
315+ for (; j < kMeansResult .parentLayer ().length ; j ++) {
316+ if (kMeansResult .parentLayer ()[j ] != label ) {
317+ break ;
318+ }
319+ for (int k = 0 ; k < parentPartitionCentroid .length ; k ++) {
320+ parentPartitionCentroid [k ] += centroids [i ][k ];
321+ }
322+ totalCentroids ++;
307323 }
308- for (int k = 0 ; k < parentPartitionCentroid .length ; k ++) {
309- parentPartitionCentroid [k ] += centroids [i ][k ];
324+ int childOrdinal = i ;
325+ i = j ;
326+ for (int d = 0 ; d < parentPartitionCentroid .length ; d ++) {
327+ parentPartitionCentroid [d ] /= totalCentroids ;
310328 }
311- totalCentroids ++;
312- }
313- int childOrdinal = i ;
314- i = j ;
315- for (int d = 0 ; d < parentPartitionCentroid .length ; d ++) {
316- parentPartitionCentroid [d ] /= totalCentroids ;
329+ centroidPartitions .add (new CentroidPartition (parentPartitionCentroid , childOrdinal , totalCentroids ));
317330 }
318- centroidPartitions .add (new CentroidPartition (parentPartitionCentroid , childOrdinal , totalCentroids ));
319331 }
320332 }
321333
@@ -330,7 +342,7 @@ CentroidAssignments calculateAndWriteCentroids(
330342 for (int c = 0 ; c < assignmentsByCluster .length ; c ++) {
331343 IntArrayList cluster = new IntArrayList (vectorPerCluster );
332344 for (int j = 0 ; j < assignments .length ; j ++) {
333- if (assignments [j ] == -1 ) {
345+ if (assignments [j ] == -1 ) {
334346 continue ;
335347 }
336348 if (assignments [j ] == centroidOrds [c ]) {
@@ -339,7 +351,7 @@ CentroidAssignments calculateAndWriteCentroids(
339351 }
340352
341353 for (int j = 0 ; j < soarAssignments .length ; j ++) {
342- if (soarAssignments [j ] == -1 ) {
354+ if (soarAssignments [j ] == -1 ) {
343355 continue ;
344356 }
345357 if (soarAssignments [j ] == centroidOrds [c ]) {
0 commit comments