@@ -33,20 +33,12 @@ class KMeansLocal {
3333 // second closest centroid.
3434 private static final float SOAR_MIN_DISTANCE = 1e-16f ;
3535
36- final int sampleSize ;
37- final int maxIterations ;
38- final int clustersPerNeighborhood ;
39- final float soarLambda ;
36+ private final int sampleSize ;
37+ private final int maxIterations ;
4038
41- KMeansLocal (int sampleSize , int maxIterations , int clustersPerNeighborhood , float soarLambda ) {
39+ KMeansLocal (int sampleSize , int maxIterations ) {
4240 this .sampleSize = sampleSize ;
4341 this .maxIterations = maxIterations ;
44- this .clustersPerNeighborhood = clustersPerNeighborhood ;
45- this .soarLambda = soarLambda ;
46- }
47-
48- KMeansLocal (int sampleSize , int maxIterations ) {
49- this (sampleSize , maxIterations , -1 , -1f );
5042 }
5143
5244 /**
@@ -179,8 +171,13 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,
179171 }
180172 }
181173
182- private int [] assignSpilled (FloatVectorValues vectors , List <int []> neighborhoods , float [][] centroids , int [] assignments )
183- throws IOException {
174+ private int [] assignSpilled (
175+ FloatVectorValues vectors ,
176+ List <int []> neighborhoods ,
177+ float [][] centroids ,
178+ int [] assignments ,
179+ float soarLambda
180+ ) throws IOException {
184181 // SOAR uses an adjusted distance for assigning spilled documents which is
185182 // given by:
186183 //
@@ -238,7 +235,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
238235 }
239236
240237 /**
241- * cluster using a lloyd k-means algorithm that is not neighbor aware
238+ * cluster using a lloyd k-means algorithm that does not consider prior clustered neighborhoods when adjusting centroids
242239 *
243240 * @param vectors the vectors to cluster
244241 * @param kMeansIntermediate the output object to populate which minimally includes centroids,
@@ -247,7 +244,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
247244 * @throws IOException is thrown if vectors is inaccessible
248245 */
249246 void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate ) throws IOException {
250- cluster (vectors , kMeansIntermediate , false );
247+ doCluster (vectors , kMeansIntermediate , - 1 , - 1 );
251248 }
252249
253250 /**
@@ -259,13 +256,23 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
259256 * the prior assignments of the given vectors; care should be taken in
260257 * passing in a valid output object with a centroids array that is the size of centroids expected
261258 * and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
262- * @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
263- * implies SOAR assignments
264- * @throws IOException is thrown if vectors is inaccessible
259+ * @param clustersPerNeighborhood number of nearby neighboring centroids to be used to update the centroid positions.
260+ * @param soarLambda lambda used for SOAR assignments
261+ *
262+ * @throws IOException is thrown if vectors is inaccessible or if the clustersPerNeighborhood is less than 2
265263 */
266- void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , boolean neighborAware ) throws IOException {
264+ void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , int clustersPerNeighborhood , float soarLambda )
265+ throws IOException {
266+ if (clustersPerNeighborhood < 2 ) {
267+ throw new IllegalArgumentException ("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]" );
268+ }
269+ doCluster (vectors , kMeansIntermediate , clustersPerNeighborhood , soarLambda );
270+ }
271+
272+ private void doCluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , int clustersPerNeighborhood , float soarLambda )
273+ throws IOException {
267274 float [][] centroids = kMeansIntermediate .centroids ();
268- boolean computeNeighborhoods = neighborAware && clustersPerNeighborhood > 0 ;
275+ boolean computeNeighborhoods = clustersPerNeighborhood != - 1 ;
269276
270277 List <int []> neighborhoods = null ;
271278 if (computeNeighborhoods ) {
@@ -281,7 +288,7 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
281288 int [] assignments = kMeansIntermediate .assignments ();
282289 assert assignments != null ;
283290 assert assignments .length == vectors .size ();
284- kMeansIntermediate .setSoarAssignments (assignSpilled (vectors , neighborhoods , centroids , assignments ));
291+ kMeansIntermediate .setSoarAssignments (assignSpilled (vectors , neighborhoods , centroids , assignments , soarLambda ));
285292 }
286293 }
287294
0 commit comments