@@ -244,32 +244,34 @@ CentroidAssignments calculateAndWriteCentroids(
244244
245245 List <CentroidPartition > centroidPartitions = new ArrayList <>();
246246
247- List <float []> centroidsList = Arrays .stream (centroids ).toList ();
248- FloatVectorValues centroidsAsFVV = FloatVectorValues .fromFloats (centroidsList , fieldInfo .getVectorDimension ());
249-
250- HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans (fieldInfo .getVectorDimension ());
251- KMeansResult result = hierarchicalKMeans .cluster (centroidsAsFVV , centroids .length / (int ) Math .sqrt (centroids .length ));
252- float [][] parentCentroids = result .centroids ();
253- int [] parentChildAssignments = result .assignments ();
254- // TODO: explore using soar assignments here as well
255- // int[] parentChildSoarAssignments = result.soarAssignments();
256-
257- AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , parentChildAssignments );
258- sorter .sort (0 , centroids .length );
259-
260- for (int i = 0 ; i < parentChildAssignments .length ; i ++) {
261- int label = parentChildAssignments [i ];
262- int centroidCount = 0 ;
263- int j = i ;
264- for (; j < parentChildAssignments .length ; j ++) {
265- if (parentChildAssignments [j ] != label ) {
266- break ;
247+ if (centroids .length > IVFVectorsFormat .DEFAULT_VECTORS_PER_CLUSTER ) {
248+ List <float []> centroidsList = Arrays .stream (centroids ).toList ();
249+ FloatVectorValues centroidsAsFVV = FloatVectorValues .fromFloats (centroidsList , fieldInfo .getVectorDimension ());
250+
251+ HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans (fieldInfo .getVectorDimension ());
252+ KMeansResult result = hierarchicalKMeans .cluster (centroidsAsFVV , centroids .length / (int ) Math .sqrt (centroids .length ));
253+ float [][] parentCentroids = result .centroids ();
254+ int [] parentChildAssignments = result .assignments ();
255+ // TODO: explore using soar assignments here as well
256+ // int[] parentChildSoarAssignments = result.soarAssignments();
257+
258+ AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , parentChildAssignments );
259+ sorter .sort (0 , centroids .length );
260+
261+ for (int i = 0 ; i < parentChildAssignments .length ; i ++) {
262+ int label = parentChildAssignments [i ];
263+ int centroidCount = 0 ;
264+ int j = i ;
265+ for (; j < parentChildAssignments .length ; j ++) {
266+ if (parentChildAssignments [j ] != label ) {
267+ break ;
268+ }
269+ centroidCount ++;
267270 }
268- centroidCount ++;
271+ int childOrdinal = i ;
272+ i = j ;
273+ centroidPartitions .add (new CentroidPartition (parentCentroids [label ], childOrdinal , centroidCount ));
269274 }
270- int childOrdinal = i ;
271- i = j ;
272- centroidPartitions .add (new CentroidPartition (parentCentroids [label ], childOrdinal , centroidCount ));
273275 }
274276
275277 writeCentroidsAndPartitions (centroidPartitions , centroids , fieldInfo , globalCentroid , centroidOutput );
0 commit comments