@@ -118,16 +118,8 @@ private static void printClusterQualityStatistics(int[][] clusters) {
118118 }
119119
120120 @ Override
121- CentroidSupplier createCentroidSupplier (
122- IndexInput centroidsInput ,
123- int numParentCentroids ,
124- int numCentroids ,
125- int numClusters ,
126- FieldInfo fieldInfo ,
127- float [] globalCentroid ,
128- IntIntMap clusterToCentroidMap
129- ) {
130- return new OffHeapCentroidSupplier (centroidsInput , numParentCentroids , numCentroids , numClusters , fieldInfo , clusterToCentroidMap );
121+ CentroidSupplier createCentroidSupplier (IndexInput centroidsInput , int numClusters , FieldInfo fieldInfo ) {
122+ return new OffHeapCentroidSupplier (centroidsInput , numClusters , fieldInfo );
131123 }
132124
133125 private static void writeQuantizedCentroid (
@@ -181,7 +173,7 @@ static IntIntMap writePartitionsAndCentroids(
181173 centroidOutput
182174 );
183175 // TODO: put at the end of the parents region
184- centroidOutput .writeInt (centroidPartition .childOrdinal ());
176+ centroidOutput .writeInt (centroidPartition .childOffset ());
185177 centroidOutput .writeInt (centroidPartition .size ());
186178 }
187179 }
@@ -252,7 +244,14 @@ CentroidAssignments calculateAndWriteCentroids(
252244 return calculateAndWriteCentroids (fieldInfo , floatVectorValues , centroidOutput , globalCentroid );
253245 }
254246
255- record CentroidPartition (float [] centroid , int childOrdinal , int size , int [] assignments ) {}
247+ /**
248+ *
249+ * @param centroid the parent centroid of some set of children
250+ * @param childOffset the offset of the first child within the partition defined by this parent
251+ * @param size the number of children in this partition
252+ * @param assignments the set of centroid ordinals (potentially duplicative) of child centroids that belong to this parent
253+ */
254+ record CentroidPartition (float [] centroid , int childOffset , int size , int [] assignments ) {}
256255
257256 /**
258257 * Calculate the centroids for the given field and write them to the given centroid output.
@@ -426,30 +425,16 @@ static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, Opt
426425
427426 static class OffHeapCentroidSupplier implements CentroidSupplier {
428427 private final IndexInput centroidsInput ;
429- private final int numClusters ;
430428 private final int dimension ;
431429 private final float [] scratch ;
432- private final long rawCentroidOffset ;
433430 private int currOrd = -1 ;
434- private final IntIntMap clusterToCentroidMap ;
435-
436- OffHeapCentroidSupplier (
437- IndexInput centroidsInput ,
438- int numParentCentroids ,
439- int numCentroids ,
440- int numClusters ,
441- FieldInfo info ,
442- IntIntMap clusterToCentroidMap
443- ) {
431+ private final int numClusters ;
432+
433+ OffHeapCentroidSupplier (IndexInput centroidsInput , int numClusters , FieldInfo info ) {
444434 this .centroidsInput = centroidsInput ;
445- this .numClusters = numClusters ;
446435 this .dimension = info .getVectorDimension ();
447436 this .scratch = new float [dimension ];
448- long quantizedVectorByteSize = dimension + 3 * Float .BYTES + Short .BYTES ;
449- long quantizedVectorNodeByteSize = quantizedVectorByteSize + Integer .BYTES ;
450- long parentNodeByteSize = quantizedVectorByteSize + 2 * Integer .BYTES ;
451- this .rawCentroidOffset = numParentCentroids * parentNodeByteSize + numCentroids * quantizedVectorNodeByteSize ;
452- this .clusterToCentroidMap = clusterToCentroidMap ;
437+ this .numClusters = numClusters ;
453438 }
454439
455440 @ Override
@@ -458,19 +443,13 @@ public int size() {
458443 }
459444
460445 @ Override
461- public float [] centroid (int clusterOrdinal ) throws IOException {
462- if (clusterOrdinal == currOrd ) {
446+ public float [] centroid (int centroidOrdinal ) throws IOException {
447+ if (centroidOrdinal == currOrd ) {
463448 return scratch ;
464449 }
465- int centroidOrdinal ;
466- if (clusterToCentroidMap != null ) {
467- centroidOrdinal = clusterToCentroidMap .get (clusterOrdinal );
468- } else {
469- centroidOrdinal = clusterOrdinal ;
470- }
471- centroidsInput .seek (rawCentroidOffset + (long ) centroidOrdinal * dimension * Float .BYTES );
450+ centroidsInput .seek ((long ) centroidOrdinal * dimension * Float .BYTES );
472451 centroidsInput .readFloats (scratch , 0 , dimension );
473- this .currOrd = clusterOrdinal ;
452+ this .currOrd = centroidOrdinal ;
474453 return scratch ;
475454 }
476455 }
0 commit comments