@@ -179,12 +179,26 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
179179 }
180180 }
181181
182+ float [] scores = new float [clustersPerNeighborhood ];
182183 for (int i = 0 ; i < k ; i ++) {
183184 NeighborQueue queue = neighborQueues .get (i );
184185 int neighborCount = queue .size ();
185186 int [] neighbors = new int [neighborCount ];
186- float [] scores = new float [clustersPerNeighborhood ];
187- float maxIntraDistance = queue .consumeNodesWithWorstScore (neighbors );
187+ float maxIntraDistance = queue .consumeNodesWithWorstScore (neighbors , scores );
188+ // Sort neighbors by their score
189+ for (int j = 0 ; j < neighborCount ; j ++) {
190+ for (int l = j + 1 ; l < neighborCount ; l ++) {
191+ if (scores [j ] > scores [l ]) {
192+ // swap
193+ int tmp = neighbors [j ];
194+ neighbors [j ] = neighbors [l ];
195+ neighbors [l ] = tmp ;
196+ float tmpScore = scores [j ];
197+ scores [j ] = scores [l ];
198+ scores [l ] = tmpScore ;
199+ }
200+ }
201+ }
188202 NeighborHood neighborHood = new NeighborHood (neighbors , maxIntraDistance );
189203 neighborhoods .set (i , neighborHood );
190204 }
@@ -211,7 +225,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
211225 float [] currentCentroid = centroids [currAssignment ];
212226
213227 // TODO: cache these?
214- // float vectorCentroidDist = assignmentDistances[i];
215228 float vectorCentroidDist = VectorUtil .squareDistance (vector , currentCentroid );
216229
217230 if (vectorCentroidDist > SOAR_MIN_DISTANCE ) {
@@ -223,24 +236,33 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
223236
224237 int bestAssignment = -1 ;
225238 float minSoar = Float .MAX_VALUE ;
226- assert neighborhoods .get (currAssignment ) != null ;
227- for (int neighbor : neighborhoods .get (currAssignment ).neighbors ()) {
228- if (neighbor == currAssignment ) {
229- continue ;
239+ int centroidCount = centroids .length ;
240+ IntToIntFunction centroidOrds = c -> c ;
241+ if (neighborhoods != null ) {
242+ assert neighborhoods .get (currAssignment ) != null ;
243+ NeighborHood neighborhood = neighborhoods .get (currAssignment );
244+ centroidCount = neighborhood .neighbors .length ;
245+ centroidOrds = c -> neighborhood .neighbors [c ];
246+ }
247+ for (int j = 0 ; j < centroidCount ; j ++) {
248+ int centroidOrd = centroidOrds .apply (j );
249+ if (centroidOrd == currAssignment ) {
250+ continue ; // skip the current assignment
230251 }
231- float [] neighborCentroid = centroids [neighbor ];
232- final float soar ;
252+ float [] centroid = centroids [centroidOrd ];
253+ float soar ;
233254 if (vectorCentroidDist > SOAR_MIN_DISTANCE ) {
234- soar = ESVectorUtil .soarDistance (vector , neighborCentroid , diffs , soarLambda , vectorCentroidDist );
255+ soar = ESVectorUtil .soarDistance (vector , centroid , diffs , soarLambda , vectorCentroidDist );
235256 } else {
236257 // if the vector is very close to the centroid, we look for the second-nearest centroid
237- soar = VectorUtil .squareDistance (vector , neighborCentroid );
258+ soar = VectorUtil .squareDistance (vector , centroid );
238259 }
239260 if (soar < minSoar ) {
240- bestAssignment = neighbor ;
241261 minSoar = soar ;
262+ bestAssignment = centroidOrd ;
242263 }
243264 }
265+
244266 assert bestAssignment != -1 : "Failed to assign soar vector to centroid" ;
245267 spilledAssignments [i ] = bestAssignment ;
246268 }
@@ -280,7 +302,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
280302 float [][] centroids = kMeansIntermediate .centroids ();
281303
282304 List <NeighborHood > neighborhoods = null ;
283- if (neighborAware ) {
305+ // if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
306+ if (neighborAware && centroids .length > clustersPerNeighborhood ) {
284307 int k = centroids .length ;
285308 neighborhoods = new ArrayList <>(k );
286309 for (int i = 0 ; i < k ; ++i ) {
0 commit comments