2929import java .io .IOException ;
3030import java .nio .ByteBuffer ;
3131import java .nio .ByteOrder ;
32- import java .util .ArrayList ;
3332import java .util .Arrays ;
34- import java .util .List ;
3533
3634/**
3735 * Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
@@ -131,7 +129,7 @@ CentroidSupplier createCentroidSupplier(
131129 }
132130
133131 static void writeCentroidsAndPartitions (
134- List < CentroidPartition > centroidPartitions ,
132+ CentroidPartition [] centroidPartitions ,
135133 float [][] centroids ,
136134 FieldInfo fieldInfo ,
137135 float [] globalCentroid ,
@@ -144,22 +142,24 @@ static void writeCentroidsAndPartitions(
144142 // TODO do we want to store these distances as well for future use?
145143 // TODO: sort centroids by global centroid (was doing so previously here)
146144
147- // write the top level partition parent nodes and their pointers to the centroids within the partition
148- // a size of 1 indicates a leaf node that did not have a parent node (orphans)
149- for (CentroidPartition centroidPartition : centroidPartitions ) {
150- System .arraycopy (centroidPartition .centroid (), 0 , centroidScratch , 0 , centroidPartition .centroid ().length );
151- OptimizedScalarQuantizer .QuantizationResult result = osq .scalarQuantize (
152- centroidScratch ,
153- quantizedScratch ,
154- (byte ) 4 ,
155- globalCentroid
156- );
157- for (int i = 0 ; i < quantizedScratch .length ; i ++) {
158- quantized [i ] = (byte ) quantizedScratch [i ];
145+ if (centroidPartitions != null ) {
146+ // write the top level partition parent nodes and their pointers to the centroids within the partition
147+ // a size of 1 indicates a leaf node that did not have a parent node (orphans)
148+ for (CentroidPartition centroidPartition : centroidPartitions ) {
149+ System .arraycopy (centroidPartition .centroid (), 0 , centroidScratch , 0 , centroidPartition .centroid ().length );
150+ OptimizedScalarQuantizer .QuantizationResult result = osq .scalarQuantize (
151+ centroidScratch ,
152+ quantizedScratch ,
153+ (byte ) 4 ,
154+ globalCentroid
155+ );
156+ for (int i = 0 ; i < quantizedScratch .length ; i ++) {
157+ quantized [i ] = (byte ) quantizedScratch [i ];
158+ }
159+ writeQuantizedValue (centroidOutput , quantized , result );
160+ centroidOutput .writeInt (centroidPartition .childOrdinal ());
161+ centroidOutput .writeInt (centroidPartition .size ());
159162 }
160- writeQuantizedValue (centroidOutput , quantized , result );
161- centroidOutput .writeInt (centroidPartition .childOrdinal ());
162- centroidOutput .writeInt (centroidPartition .size ());
163163 }
164164
165165 // write the quantized centroids which will be duplicate for orphans
@@ -242,43 +242,41 @@ CentroidAssignments calculateAndWriteCentroids(
242242 centroidOrds [i ] = i ;
243243 }
244244
245- List <CentroidPartition > centroidPartitions = new ArrayList <>();
245+ CentroidPartition [] centroidPartitions = null ;
246+ int partitionsCount = 0 ;
246247
247248 if (centroids .length > IVFVectorsFormat .DEFAULT_VECTORS_PER_CLUSTER ) {
248- List <float []> centroidsList = Arrays .stream (centroids ).toList ();
249- FloatVectorValues centroidsAsFVV = FloatVectorValues .fromFloats (centroidsList , fieldInfo .getVectorDimension ());
250-
251- HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans (fieldInfo .getVectorDimension ());
252- KMeansResult result = hierarchicalKMeans .cluster (centroidsAsFVV , centroids .length / (int ) Math .sqrt (centroids .length ));
249+ KMeansResult result = clusterParentCentroids (fieldInfo , centroids );
253250 float [][] parentCentroids = result .centroids ();
254251 int [] parentChildAssignments = result .assignments ();
255- // TODO: explore using soar assignments here as well
256- // int[] parentChildSoarAssignments = result.soarAssignments();
252+ // TODO: explore soar assignments here as well
253+
254+ centroidPartitions = new CentroidPartition [parentCentroids .length ];
257255
258256 AssignmentArraySorter sorter = new AssignmentArraySorter (centroids , centroidOrds , parentChildAssignments );
259257 sorter .sort (0 , centroids .length );
260258
261- for (int i = 0 ; i < parentChildAssignments .length ; i ++ ) {
259+ for (int i = 0 ; i < parentChildAssignments .length ;) {
262260 int label = parentChildAssignments [i ];
263261 int centroidCount = 0 ;
262+ int childOffset = i ;
264263 int j = i ;
265264 for (; j < parentChildAssignments .length ; j ++) {
266265 if (parentChildAssignments [j ] != label ) {
267266 break ;
268267 }
269268 centroidCount ++;
270269 }
271- int childOrdinal = i ;
272270 i = j ;
273- centroidPartitions . add ( new CentroidPartition (parentCentroids [label ], childOrdinal , centroidCount ) );
271+ centroidPartitions [ partitionsCount ++] = new CentroidPartition (parentCentroids [label ], childOffset , centroidCount );
274272 }
275273 }
276274
277275 writeCentroidsAndPartitions (centroidPartitions , centroids , fieldInfo , globalCentroid , centroidOutput );
278276
279277 if (logger .isDebugEnabled ()) {
280278 logger .debug ("calculate centroids and assign vectors time ms: {}" , (System .nanoTime () - nanoTime ) / 1000000.0 );
281- logger .debug ("final parent centroid count {}: " , centroidPartitions . size () );
279+ logger .debug ("final parent centroid count {}: " , partitionsCount );
282280 logger .debug ("final centroid count: {}" , centroids .length );
283281 }
284282
@@ -291,7 +289,40 @@ CentroidAssignments calculateAndWriteCentroids(
291289 int [] soarAssignments = kMeansResult .soarAssignments ();
292290
293291 int [][] assignmentsByCluster = buildCentroidAssignments (centroids .length , assignments , soarAssignments , centroidOrdsToIdx );
294- return new CentroidAssignments (centroidPartitions .size (), centroids , assignmentsByCluster );
292+ return new CentroidAssignments (partitionsCount , centroids , assignmentsByCluster );
293+ }
294+
295+ private KMeansResult clusterParentCentroids (FieldInfo fieldInfo , float [][] centroids ) throws IOException {
296+ FloatVectorValues centroidsAsFVV = new FloatVectorValues () {
297+ @ Override
298+ public int size () {
299+ return centroids .length ;
300+ }
301+
302+ @ Override
303+ public int dimension () {
304+ return fieldInfo .getVectorDimension ();
305+ }
306+
307+ @ Override
308+ public float [] vectorValue (int targetOrd ) {
309+ return centroids [targetOrd ];
310+ }
311+
312+ @ Override
313+ public FloatVectorValues copy () {
314+ return this ;
315+ }
316+
317+ @ Override
318+ public DocIndexIterator iterator () {
319+ return createDenseIterator ();
320+ }
321+ };
322+
323+ HierarchicalKMeans hierarchicalKMeans = new HierarchicalKMeans (fieldInfo .getVectorDimension ());
324+ KMeansResult result = hierarchicalKMeans .cluster (centroidsAsFVV , centroids .length / (int ) Math .sqrt (centroids .length ));
325+ return result ;
295326 }
296327
297328 static int [][] buildCentroidAssignments (int centroidCount , int [] assignments , int [] soarAssignments , IntIntMap centroidOrds ) {
0 commit comments