99
1010package org .elasticsearch .index .codec .vectors ;
1111
12+ import com .carrotsearch .hppc .IntIntHashMap ;
13+ import com .carrotsearch .hppc .IntIntMap ;
14+
1215import org .apache .lucene .codecs .hnsw .FlatVectorsWriter ;
1316import org .apache .lucene .index .FieldInfo ;
1417import org .apache .lucene .index .FloatVectorValues ;
3033import java .util .Arrays ;
3134import java .util .List ;
3235
33- import static org .elasticsearch .index .codec .vectors .IVFVectorsFormat .DEFAULT_VECTORS_PER_CLUSTER ;
34-
3536/**
3637 * Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
3738 * partition the vector space, and then stores the centroids and posting list in a sequential
@@ -153,7 +154,10 @@ static void writeCentroidsAndPartitions(
153154 (byte ) 4 ,
154155 globalCentroid
155156 );
156- writeQuantizedValue (centroidOutput , quantizedScratch , result );
157+ for (int i = 0 ; i < quantizedScratch .length ; i ++) {
158+ quantized [i ] = (byte ) quantizedScratch [i ];
159+ }
160+ writeQuantizedValue (centroidOutput , quantized , result );
157161 centroidOutput .writeInt (centroidPartition .childOrdinal ());
158162 centroidOutput .writeInt (centroidPartition .size ());
159163 }
@@ -252,40 +256,32 @@ CentroidAssignments calculateAndWriteCentroids(
252256
253257 List <CentroidPartition > centroidPartitions = new ArrayList <>();
254258
255- // TODO: make this configurable
256- if (centroids .length > DEFAULT_VECTORS_PER_CLUSTER ) {
257- // TODO: sort by global centroids as well
258- // TODO: have this take a function instead of just an int[] for sorting
259- AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , kMeansResult .parentLayer ());
260- sorter .sort (0 , centroids .length );
261-
262- for (int i = 0 ; i < kMeansResult .parentLayer ().length ;) {
263- // for any layer that was not partitioned we treat it duplicatively as a parent and child
264- if (kMeansResult .parentLayer ()[i ] == -1 ) {
265- centroidPartitions .add (new CentroidPartition (centroids [i ], i , 1 ));
266- i ++;
267- } else {
268- int label = kMeansResult .parentLayer ()[i ];
269- int centroidCount = 0 ;
270- float [] parentPartitionCentroid = new float [fieldInfo .getVectorDimension ()];
271- int j = i ;
272- for (; j < kMeansResult .parentLayer ().length ; j ++) {
273- if (kMeansResult .parentLayer ()[j ] != label ) {
274- break ;
275- }
276- for (int k = 0 ; k < parentPartitionCentroid .length ; k ++) {
277- parentPartitionCentroid [k ] += centroids [i ][k ];
278- }
279- centroidCount ++;
280- }
281- int childOrdinal = i ;
282- i = j ;
283- for (int d = 0 ; d < parentPartitionCentroid .length ; d ++) {
284- parentPartitionCentroid [d ] /= centroidCount ;
285- }
286- centroidPartitions .add (new CentroidPartition (parentPartitionCentroid , childOrdinal , centroidCount ));
259+ List <float []> centroidsList = Arrays .stream (centroids ).toList ();
260+ FloatVectorValues centroidsAsFVV = FloatVectorValues .fromFloats (centroidsList , fieldInfo .getVectorDimension ());
261+
262+ HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans (fieldInfo .getVectorDimension ());
263+ KMeansResult result = hierarchicalKMeans .cluster (centroidsAsFVV , centroids .length / (int ) Math .sqrt (centroids .length ));
264+ float [][] parentCentroids = result .centroids ();
265+ int [] parentChildAssignments = result .assignments ();
266+ // TODO: explore using soar assignments here as well
267+ //int[] parentChildSoarAssignments = result.soarAssignments();
268+
269+ AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , parentChildAssignments );
270+ sorter .sort (0 , centroids .length );
271+
272+ for (int i = 0 ; i < parentChildAssignments .length ; i ++) {
273+ int label = parentChildAssignments [i ];
274+ int centroidCount = 0 ;
275+ int j = i ;
276+ for (; j < parentChildAssignments .length ; j ++) {
277+ if (parentChildAssignments [j ] != label ) {
278+ break ;
287279 }
280+ centroidCount ++;
288281 }
282+ int childOrdinal = i ;
283+ i = j ;
284+ centroidPartitions .add (new CentroidPartition (parentCentroids [label ], childOrdinal , centroidCount ));
289285 }
290286
291287 writeCentroidsAndPartitions (centroidPartitions , centroids , fieldInfo , globalCentroid , centroidOutput );
@@ -298,7 +294,11 @@ CentroidAssignments calculateAndWriteCentroids(
298294 logger .debug ("final centroid count: {}" , centroids .length );
299295 }
300296
301- int [][] assignmentsByCluster = mapAssignmentsByCluster (centroids .length , assignments , soarAssignments , centroidOrds );
297+ IntIntMap centroidOrdsToIdx = new IntIntHashMap (centroidOrds .length );
298+ for (int i = 0 ; i < centroidOrds .length ; i ++) {
299+ centroidOrdsToIdx .put (centroidOrds [i ], i );
300+ }
301+ int [][] assignmentsByCluster = mapAssignmentsByCluster (centroids .length , assignments , soarAssignments , centroidOrdsToIdx );
302302
303303 if (cacheCentroids ) {
304304 return new CentroidAssignments (centroidPartitions .size (), centroids , assignmentsByCluster );
@@ -307,26 +307,14 @@ CentroidAssignments calculateAndWriteCentroids(
307307 }
308308 }
309309
310- // FIXME: clean this up
311- static int [][] mapAssignmentsByCluster (int centroidCount , int [] assignments , int [] soarAssignments , int [] centroidOrds ) {
310+ static int [][] mapAssignmentsByCluster (int centroidCount , int [] assignments , int [] soarAssignments , IntIntMap centroidOrds ) {
312311 int [] centroidVectorCount = new int [centroidCount ];
313312 for (int i = 0 ; i < assignments .length ; i ++) {
314- int c = -1 ;
315- // FIXME: create a reverse mapping prior to this step? .. expensive
316- for (int j = 0 ; j < centroidOrds .length ; j ++) {
317- if (assignments [i ] == centroidOrds [j ]) {
318- c = j ;
319- }
320- }
313+ int c = centroidOrds .get (assignments [i ]);
321314 centroidVectorCount [c ]++;
322315 // if soar assignments are present, count them as well
323316 if (soarAssignments .length > i && soarAssignments [i ] != -1 ) {
324- int s = -1 ;
325- for (int j = 0 ; j < centroidOrds .length ; j ++) {
326- if (soarAssignments [i ] == centroidOrds [j ]) {
327- s = j ;
328- }
329- }
317+ int s = centroidOrds .get (soarAssignments [i ]);
330318 centroidVectorCount [s ]++;
331319 }
332320 }
@@ -338,21 +326,11 @@ static int[][] mapAssignmentsByCluster(int centroidCount, int[] assignments, int
338326 Arrays .fill (centroidVectorCount , 0 );
339327
340328 for (int i = 0 ; i < assignments .length ; i ++) {
341- int c = -1 ;
342- for (int j = 0 ; j < centroidOrds .length ; j ++) {
343- if (assignments [i ] == centroidOrds [j ]) {
344- c = j ;
345- }
346- }
329+ int c = centroidOrds .get (assignments [i ]);
347330 assignmentsByCluster [c ][centroidVectorCount [c ]++] = i ;
348331 // if soar assignments are present, add them to the cluster as well
349332 if (soarAssignments .length > i ) {
350- int s = -1 ;
351- for (int j = 0 ; j < centroidOrds .length ; j ++) {
352- if (soarAssignments [i ] == centroidOrds [j ]) {
353- s = j ;
354- }
355- }
333+ int s = centroidOrds .getOrDefault (soarAssignments [i ], -1 );
356334 if (s != -1 ) {
357335 assignmentsByCluster [s ][centroidVectorCount [s ]++] = i ;
358336 }
0 commit comments