@@ -87,17 +87,17 @@ private boolean stepLloyd(
8787
8888 for (int i = 0 ; i < sampleSize ; i ++) {
8989 float [] vector = vectors .vectorValue (i );
90- int [] neighborOffsets = null ;
91- int centroidIdx = - 1 ;
90+ final int assignment = assignments [ i ] ;
91+ final int bestCentroidOffset ;
9292 if (neighborhoods != null ) {
93- neighborOffsets = neighborhoods .get (assignments [i ]);
94- centroidIdx = assignments [i ];
93+ bestCentroidOffset = getBestCentroidFromNeighbours (centroids , vector , assignment , neighborhoods .get (assignment ));
94+ } else {
95+ bestCentroidOffset = getBestCentroid (centroids , vector );
9596 }
96- int bestCentroidOffset = getBestCentroidOffset ( centroids , vector , centroidIdx , neighborOffsets );
97- if ( assignments [i ] ! = bestCentroidOffset ) {
97+ if ( assignment != bestCentroidOffset ) {
98+ assignments [i ] = bestCentroidOffset ;
9899 changed = true ;
99100 }
100- assignments [i ] = bestCentroidOffset ;
101101 centroidCounts [bestCentroidOffset ]++;
102102 for (int d = 0 ; d < dim ; d ++) {
103103 nextCentroids [bestCentroidOffset ][d ] += vector [d ];
@@ -116,23 +116,28 @@ private boolean stepLloyd(
116116 return changed ;
117117 }
118118
119- int getBestCentroidOffset (float [][] centroids , float [] vector , int centroidIdx , int [] centroidOffsets ) {
119+ int getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int centroidIdx , int [] centroidOffsets ) {
120120 int bestCentroidOffset = centroidIdx ;
121- float minDsq ;
122- if (centroidIdx > 0 && centroidIdx < centroids .length ) {
123- minDsq = VectorUtil .squareDistance (vector , centroids [centroidIdx ]);
124- } else {
125- minDsq = Float .MAX_VALUE ;
121+ assert centroidIdx >= 0 && centroidIdx < centroids .length ;
122+ float minDsq = VectorUtil .squareDistance (vector , centroids [centroidIdx ]);
123+ for (int offset : centroidOffsets ) {
124+ float dsq = VectorUtil .squareDistance (vector , centroids [offset ]);
125+ if (dsq < minDsq ) {
126+ minDsq = dsq ;
127+ bestCentroidOffset = offset ;
128+ }
126129 }
130+ return bestCentroidOffset ;
131+ }
127132
128- int k = 0 ;
129- for ( int j = 0 ; j < centroids . length ; j ++) {
130- if ( centroidOffsets == null || j == centroidOffsets [ k ]) {
131- float dsq = VectorUtil . squareDistance ( vector , centroids [ j ]);
132- if ( dsq < minDsq ) {
133- minDsq = dsq ;
134- bestCentroidOffset = j ;
135- }
133+ int getBestCentroid ( float [][] centroids , float [] vector ) {
134+ int bestCentroidOffset = 0 ;
135+ float minDsq = Float . MAX_VALUE ;
136+ for ( int i = 0 ; i < centroids . length ; i ++) {
137+ float dsq = VectorUtil . squareDistance ( vector , centroids [ i ]);
138+ if ( dsq < minDsq ) {
139+ minDsq = dsq ;
140+ bestCentroidOffset = i ;
136141 }
137142 }
138143 return bestCentroidOffset ;
@@ -271,7 +276,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
271276 return ;
272277 }
273278
274- int [] assignments = new int [n ];
279+ int [] assignments = kMeansIntermediate .assignments ();
280+ assert assignments .length == n ;
275281 float [][] nextCentroids = new float [centroids .length ][vectors .dimension ()];
276282 for (int i = 0 ; i < maxIterations ; i ++) {
277283 if (stepLloyd (vectors , centroids , nextCentroids , assignments , sampleSize , neighborhoods ) == false ) {
@@ -291,7 +297,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, L
291297 * @param maxIterations the max iterations to shift centroids
292298 */
293299 public static void cluster (FloatVectorValues vectors , float [][] centroids , int sampleSize , int maxIterations ) throws IOException {
294- KMeansIntermediate kMeansIntermediate = new KMeansIntermediate (centroids );
300+ KMeansIntermediate kMeansIntermediate = new KMeansIntermediate (centroids , new int [ vectors . size ()], vectors :: ordToDoc );
295301 KMeansLocal kMeans = new KMeansLocal (sampleSize , maxIterations );
296302 kMeans .cluster (vectors , kMeansIntermediate );
297303 }
0 commit comments