@@ -82,7 +82,7 @@ private static boolean stepLloyd(
8282        float [][] centroids ,
8383        float [][] nextCentroids ,
8484        int [] assignments ,
85-         List <int [] > neighborhoods 
85+         List <NeighborHood > neighborhoods 
8686    ) throws  IOException  {
8787        boolean  changed  = false ;
8888        int  dim  = vectors .dimension ();
@@ -124,11 +124,20 @@ private static boolean stepLloyd(
124124        return  changed ;
125125    }
126126
127-     private  static  int  getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int  centroidIdx , int []  centroidOffsets ) {
127+     private  static  int  getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int  centroidIdx , NeighborHood   neighborhood ) {
128128        int  bestCentroidOffset  = centroidIdx ;
129129        assert  centroidIdx  >= 0  && centroidIdx  < centroids .length ;
130130        float  minDsq  = VectorUtil .squareDistance (vector , centroids [centroidIdx ]);
131-         for  (int  offset  : centroidOffsets ) {
131+         for  (int  i  = 0 ; i  < neighborhood .neighbors .length ; i ++) {
132+             int  offset  = neighborhood .neighbors [i ];
133+             // float score = neighborhood.scores[i]; 
134+             assert  offset  >= 0  && offset  < centroids .length  : "Invalid neighbor offset: "  + offset ;
135+             if  (minDsq  < neighborhood .maxIntraDistance ) {
136+                 // if the distance found is smaller than the maximum intra-cluster distance 
137+                 // we don't consider it for further re-assignment 
138+                 return  bestCentroidOffset ;
139+             }
140+             // compute the distance to the centroid 
132141            float  dsq  = VectorUtil .squareDistance (vector , centroids [offset ]);
133142            if  (dsq  < minDsq ) {
134143                minDsq  = dsq ;
@@ -151,7 +160,7 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
151160        return  bestCentroidOffset ;
152161    }
153162
154-     private  void  computeNeighborhoods (float [][] centers , List <int [] > neighborhoods , int  clustersPerNeighborhood ) {
163+     private  void  computeNeighborhoods (float [][] centers , List <NeighborHood > neighborhoods , int  clustersPerNeighborhood ) {
155164        int  k  = neighborhoods .size ();
156165
157166        if  (k  == 0  || clustersPerNeighborhood  <= 0 ) {
@@ -172,14 +181,24 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,
172181
173182        for  (int  i  = 0 ; i  < k ; i ++) {
174183            NeighborQueue  queue  = neighborQueues .get (i );
175-             int  neighborCount  = queue .size ();
176-             int [] neighbors  = new  int [neighborCount ];
177-             queue .consumeNodes (neighbors );
178-             neighborhoods .set (i , neighbors );
184+             if  (queue .size () == 0 ) {
185+                 // no neighbors, skip 
186+                 neighborhoods .set (i , NeighborHood .EMPTY );
187+                 continue ;
188+             }
189+             // consume the queue into the neighbors array and get the maximum intra-cluster distance 
190+             int [] neighbors  = new  int [queue .size ()];
191+             float  maxIntraDistance  = queue .topScore ();
192+             int  iter  = 0 ;
193+             while  (queue .size () > 0 ) {
194+                 neighbors [neighbors .length  - ++iter ] = queue .pop ();
195+             }
196+             NeighborHood  neighborHood  = new  NeighborHood (neighbors , maxIntraDistance );
197+             neighborhoods .set (i , neighborHood );
179198        }
180199    }
181200
182-     private  int [] assignSpilled (FloatVectorValues  vectors , List <int [] > neighborhoods , float [][] centroids , int [] assignments )
201+     private  int [] assignSpilled (FloatVectorValues  vectors , List <NeighborHood > neighborhoods , float [][] centroids , int [] assignments )
183202        throws  IOException  {
184203        // SOAR uses an adjusted distance for assigning spilled documents which is 
185204        // given by: 
@@ -200,7 +219,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
200219            float [] currentCentroid  = centroids [currAssignment ];
201220
202221            // TODO: cache these? 
203-             // float vectorCentroidDist = assignmentDistances[i]; 
204222            float  vectorCentroidDist  = VectorUtil .squareDistance (vector , currentCentroid );
205223
206224            if  (vectorCentroidDist  > SOAR_MIN_DISTANCE ) {
@@ -212,24 +230,33 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
212230
213231            int  bestAssignment  = -1 ;
214232            float  minSoar  = Float .MAX_VALUE ;
215-             assert  neighborhoods .get (currAssignment ) != null ;
216-             for  (int  neighbor  : neighborhoods .get (currAssignment )) {
217-                 if  (neighbor  == currAssignment ) {
218-                     continue ;
233+             int  centroidCount  = centroids .length ;
234+             IntToIntFunction  centroidOrds  = c  -> c ;
235+             if  (neighborhoods  != null ) {
236+                 assert  neighborhoods .get (currAssignment ) != null ;
237+                 NeighborHood  neighborhood  = neighborhoods .get (currAssignment );
238+                 centroidCount  = neighborhood .neighbors .length ;
239+                 centroidOrds  = c  -> neighborhood .neighbors [c ];
240+             }
241+             for  (int  j  = 0 ; j  < centroidCount ; j ++) {
242+                 int  centroidOrd  = centroidOrds .apply (j );
243+                 if  (centroidOrd  == currAssignment ) {
244+                     continue ; // skip the current assignment 
219245                }
220-                 float [] neighborCentroid  = centroids [neighbor ];
221-                 final   float  soar ;
246+                 float [] centroid  = centroids [centroidOrd ];
247+                 float  soar ;
222248                if  (vectorCentroidDist  > SOAR_MIN_DISTANCE ) {
223-                     soar  = ESVectorUtil .soarDistance (vector , neighborCentroid , diffs , soarLambda , vectorCentroidDist );
249+                     soar  = ESVectorUtil .soarDistance (vector , centroid , diffs , soarLambda , vectorCentroidDist );
224250                } else  {
225251                    // if the vector is very close to the centroid, we look for the second-nearest centroid 
226-                     soar  = VectorUtil .squareDistance (vector , neighborCentroid );
252+                     soar  = VectorUtil .squareDistance (vector , centroid );
227253                }
228254                if  (soar  < minSoar ) {
229-                     bestAssignment  = neighbor ;
230255                    minSoar  = soar ;
256+                     bestAssignment  = centroidOrd ;
231257                }
232258            }
259+ 
233260            assert  bestAssignment  != -1  : "Failed to assign soar vector to centroid" ;
234261            spilledAssignments [i ] = bestAssignment ;
235262        }
@@ -250,6 +277,10 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
250277        cluster (vectors , kMeansIntermediate , false );
251278    }
252279
280+     record  NeighborHood (int [] neighbors , float  maxIntraDistance ) {
             // Neighbor list for one centroid together with the distance cutoff used during re-assignment.
             //
             // neighbors:        centroid ordinals; callers in this file use each entry to index the centroids array.
             // maxIntraDistance: taken from the neighbor queue's top score when the neighborhood is built; the
             //                   neighbor scan returns early once the best squared distance found so far is below it.
             //
             // The record declares no compact constructor, so the neighbors array is stored as passed in
             // (no defensive copy) and is shared with whoever created it.

             // Placeholder stored for a centroid whose neighbor queue was empty: it has no neighbors to scan.
             // NOTE(review): the infinite cutoff is only compared inside the loop over neighbors, which never
             // runs for an empty array - confirm no other reader depends on this value.
281+         static  final  NeighborHood  EMPTY  = new  NeighborHood (new  int [0 ], Float .POSITIVE_INFINITY );
282+     }
283+ 
253284    /** 
254285     * cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids 
255286     * this also is used to generate the neighborhood aware additional (SOAR) assignments 
@@ -266,8 +297,9 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
266297    void  cluster (FloatVectorValues  vectors , KMeansIntermediate  kMeansIntermediate , boolean  neighborAware ) throws  IOException  {
267298        float [][] centroids  = kMeansIntermediate .centroids ();
268299
269-         List <int []> neighborhoods  = null ;
270-         if  (neighborAware ) {
300+         List <NeighborHood > neighborhoods  = null ;
301+         // if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering 
302+         if  (neighborAware  && centroids .length  > clustersPerNeighborhood ) {
271303            int  k  = centroids .length ;
272304            neighborhoods  = new  ArrayList <>(k );
273305            for  (int  i  = 0 ; i  < k ; ++i ) {
@@ -284,7 +316,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
284316        }
285317    }
286318
287-     private  void  cluster (FloatVectorValues  vectors , KMeansIntermediate  kMeansIntermediate , List <int []> neighborhoods ) throws  IOException  {
319+     private  void  cluster (FloatVectorValues  vectors , KMeansIntermediate  kMeansIntermediate , List <NeighborHood > neighborhoods )
320+         throws  IOException  {
288321        float [][] centroids  = kMeansIntermediate .centroids ();
289322        int  k  = centroids .length ;
290323        int  n  = vectors .size ();
0 commit comments