@@ -82,7 +82,7 @@ private static boolean stepLloyd(
8282        float [][] centroids ,
8383        float [][] nextCentroids ,
8484        int [] assignments ,
85-         List <int [] > neighborhoods 
85+         List <NeighborHood > neighborhoods 
8686    ) throws  IOException  {
8787        boolean  changed  = false ;
8888        int  dim  = vectors .dimension ();
@@ -124,11 +124,20 @@ private static boolean stepLloyd(
124124        return  changed ;
125125    }
126126
127-     private  static  int  getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int  centroidIdx , int []  centroidOffsets ) {
127+     private  static  int  getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int  centroidIdx , NeighborHood   neighborhood ) {
128128        int  bestCentroidOffset  = centroidIdx ;
129129        assert  centroidIdx  >= 0  && centroidIdx  < centroids .length ;
130130        float  minDsq  = VectorUtil .squareDistance (vector , centroids [centroidIdx ]);
131-         for  (int  offset  : centroidOffsets ) {
131+         for  (int  i  = 0 ; i  < neighborhood .neighbors .length ; i ++) {
132+             int  offset  = neighborhood .neighbors [i ];
133+             // float score = neighborhood.scores[i]; 
134+             assert  offset  >= 0  && offset  < centroids .length  : "Invalid neighbor offset: "  + offset ;
135+             if  (minDsq  < neighborhood .maxIntraDistance ) {
136+                 // if the distance found is smaller than the maximum intra-cluster distance 
137+                 // we don't consider it for further re-assignment 
138+                 return  bestCentroidOffset ;
139+             }
140+             // compute the distance to the centroid 
132141            float  dsq  = VectorUtil .squareDistance (vector , centroids [offset ]);
133142            if  (dsq  < minDsq ) {
134143                minDsq  = dsq ;
@@ -151,7 +160,7 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
151160        return  bestCentroidOffset ;
152161    }
153162
154-     private  void  computeNeighborhoods (float [][] centers , List <int [] > neighborhoods , int  clustersPerNeighborhood ) {
163+     private  void  computeNeighborhoods (float [][] centers , List <NeighborHood > neighborhoods , int  clustersPerNeighborhood ) {
155164        int  k  = neighborhoods .size ();
156165
157166        if  (k  == 0  || clustersPerNeighborhood  <= 0 ) {
@@ -174,12 +183,14 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,
174183            NeighborQueue  queue  = neighborQueues .get (i );
175184            int  neighborCount  = queue .size ();
176185            int [] neighbors  = new  int [neighborCount ];
177-             queue .consumeNodes (neighbors );
178-             neighborhoods .set (i , neighbors );
186+             float [] scores  = new  float [clustersPerNeighborhood ];
187+             float  maxIntraDistance  = queue .consumeNodesWithWorstScore (neighbors );
188+             NeighborHood  neighborHood  = new  NeighborHood (neighbors , maxIntraDistance );
189+             neighborhoods .set (i , neighborHood );
179190        }
180191    }
181192
182-     private  int [] assignSpilled (FloatVectorValues  vectors , List <int [] > neighborhoods , float [][] centroids , int [] assignments )
193+     private  int [] assignSpilled (FloatVectorValues  vectors , List <NeighborHood > neighborhoods , float [][] centroids , int [] assignments )
183194        throws  IOException  {
184195        // SOAR uses an adjusted distance for assigning spilled documents which is 
185196        // given by: 
@@ -213,7 +224,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
213224            int  bestAssignment  = -1 ;
214225            float  minSoar  = Float .MAX_VALUE ;
215226            assert  neighborhoods .get (currAssignment ) != null ;
216-             for  (int  neighbor  : neighborhoods .get (currAssignment )) {
227+             for  (int  neighbor  : neighborhoods .get (currAssignment ). neighbors () ) {
217228                if  (neighbor  == currAssignment ) {
218229                    continue ;
219230                }
@@ -250,6 +261,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
250261        cluster (vectors , kMeansIntermediate , false );
251262    }
252263
264+     record  NeighborHood (int [] neighbors , float  maxIntraDistance ) {} // neighbors: centroid offsets to check for re-assignment; maxIntraDistance: value returned by consumeNodesWithWorstScore for this centroid's queue, compared against the squared distance to the current centroid as an early-exit threshold
265+ 
253266    /** 
254267     * cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids 
255268     * this also is used to generate the neighborhood aware additional (SOAR) assignments 
@@ -266,7 +279,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
266279    void  cluster (FloatVectorValues  vectors , KMeansIntermediate  kMeansIntermediate , boolean  neighborAware ) throws  IOException  {
267280        float [][] centroids  = kMeansIntermediate .centroids ();
268281
269-         List <int [] > neighborhoods  = null ;
282+         List <NeighborHood > neighborhoods  = null ;
270283        if  (neighborAware ) {
271284            int  k  = centroids .length ;
272285            neighborhoods  = new  ArrayList <>(k );
@@ -284,7 +297,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
284297        }
285298    }
286299
287-     private  void  cluster (FloatVectorValues  vectors , KMeansIntermediate  kMeansIntermediate , List <int []> neighborhoods ) throws  IOException  {
300+     private  void  cluster (FloatVectorValues  vectors , KMeansIntermediate  kMeansIntermediate , List <NeighborHood > neighborhoods )
301+         throws  IOException  {
288302        float [][] centroids  = kMeansIntermediate .centroids ();
289303        int  k  = centroids .length ;
290304        int  n  = vectors .size ();
0 commit comments