1010package org .elasticsearch .index .codec .vectors .cluster ;
1111
1212import org .apache .lucene .index .FloatVectorValues ;
13- import org .apache .lucene .util .VectorUtil ;
1413
1514import java .io .IOException ;
1615
@@ -85,42 +84,16 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
8584
8685 // TODO: instead of creating a new sub-cluster assignments array, reuse the parent array each time
8786 int [] assignments = new int [vectors .size ()];
88-
8987 KMeansLocal kmeans = new KMeansLocal (m , maxIterations );
9088 float [][] centroids = KMeansLocal .pickInitialCentroids (vectors , k );
91- KMeansIntermediate kMeansIntermediate = new KMeansIntermediate (centroids );
89+ KMeansIntermediate kMeansIntermediate = new KMeansIntermediate (centroids , assignments , vectors :: ordToDoc );
9290 kmeans .cluster (vectors , kMeansIntermediate );
9391
9492 // TODO: consider adding cluster size counts to the kmeans algo
9593 // handle assignment here so we can track distance and cluster size
9694 int [] centroidVectorCount = new int [centroids .length ];
97- float [][] nextCentroids = new float [centroids .length ][dimension ];
98- for (int i = 0 ; i < vectors .size (); i ++) {
99- float smallest = Float .MAX_VALUE ;
100- int centroidIdx = -1 ;
101- float [] vector = vectors .vectorValue (i );
102- for (int j = 0 ; j < centroids .length ; j ++) {
103- float [] centroid = centroids [j ];
104- float d = VectorUtil .squareDistance (vector , centroid );
105- if (d < smallest ) {
106- smallest = d ;
107- centroidIdx = j ;
108- }
109- }
110- centroidVectorCount [centroidIdx ]++;
111- for (int j = 0 ; j < dimension ; j ++) {
112- nextCentroids [centroidIdx ][j ] += vector [j ];
113- }
114- assignments [i ] = centroidIdx ;
115- }
116-
117- // update centroids based on assignments of all vectors
118- for (int i = 0 ; i < centroids .length ; i ++) {
119- if (centroidVectorCount [i ] > 0 ) {
120- for (int j = 0 ; j < dimension ; j ++) {
121- centroids [i ][j ] = nextCentroids [i ][j ] / centroidVectorCount [i ];
122- }
123- }
95+ for (int assigment : assignments ) {
96+ centroidVectorCount [assigment ]++;
12497 }
12598
12699 int effectiveK = 0 ;
@@ -130,8 +103,6 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
130103 }
131104 }
132105
133- kMeansIntermediate = new KMeansIntermediate (centroids , assignments , vectors ::ordToDoc );
134-
135106 if (effectiveK == 1 ) {
136107 return kMeansIntermediate ;
137108 }
0 commit comments