1313import org .apache .lucene .util .VectorUtil ;
1414
1515import java .io .IOException ;
16- import java .util .ArrayList ;
17- import java .util .Collections ;
18- import java .util .List ;
1916import java .util .Random ;
2017
2118/**
@@ -31,59 +28,26 @@ class KMeans {
3128 this .maxIterations = maxIterations ;
3229 }
3330
34- // FIXME: use me or remove me
35- private static void shuffle (int [] items , Random random ) {
36- if (items == null || items .length < 2 ) {
37- return ;
38- }
39-
40- for (int i = items .length - 1 ; i > 0 ; i --) {
41- int index = random .nextInt (i + 1 );
42- int temp = items [i ];
43- items [i ] = items [index ];
44- items [index ] = temp ;
45- }
46- }
47-
4831 /**
49- * uses a FORGY approach to picking the initial centroids which are subsequently expected to be used by a clustering algorithm
32+ * uses Reservoir Sampling to pick the initial centroids, which are subsequently expected
33+ * to be used by a clustering algorithm
5034 *
5135 * @param vectors used to pick an initial set of random centroids
52- * @param sampleSize the total number of vectors to be used as part of the sample for centroids
5336 * @param centroidCount the total number of centroids to pick
5437 * @return randomly selected centroids; the number returned is the min of centroidCount and the number of vectors
5538 * @throws IOException is thrown if vectors is inaccessible
5639 */
57- static float [][] pickInitialCentroids (FloatVectorValues vectors , int sampleSize , int centroidCount ) throws IOException {
58- // Choose data points as random ensuring we have distinct points where possible
59-
60- // FIXME: use me or remove me
61- // int[] candidates = IntStream.range(0, sampleSize).toArray();
62- // shuffle(candidates, new Random(42L));
63-
64- List <Integer > candidates = new ArrayList <>(sampleSize );
65- for (int i = 0 ; i < sampleSize ; i ++) {
66- candidates .add (i );
67- }
68- Collections .shuffle (candidates , new Random (42L ));
69-
70- float [][] centroids = new float [centroidCount ][vectors .dimension ()];
71- int centroidIdx = 0 ;
72- for (int i = 0 ; i < candidates .size () && centroidIdx < centroidCount ; i ++) {
73- int cand = candidates .get (i );
74- float [] vector = vectors .vectorValue (cand );
75- boolean goodCandidate = true ;
76- if (((candidates .size () - i ) - (centroidCount - centroidIdx )) > 0 ) {
77- for (int j = 0 ; j < centroidIdx ; j ++) {
78- if ((VectorUtil .squareDistance (vector , centroids [j ]) > 0.0f ) == false ) {
79- goodCandidate = false ;
80- break ;
81- }
82- }
83- }
84- if (goodCandidate ) {
85- System .arraycopy (vector , 0 , centroids [centroidIdx ], 0 , vector .length );
86- centroidIdx ++;
40+ static float [][] pickInitialCentroids (FloatVectorValues vectors , int m , int centroidCount ) throws IOException {
41+ Random random = new Random (42L );
42+ int centroidsSize = Math .min (vectors .size (), centroidCount );
43+ float [][] centroids = new float [centroidsSize ][vectors .dimension ()];
44+ for (int i = 0 ; i < vectors .size (); i ++) {
45+ float [] vector = vectors .vectorValue (i );
46+ if (i < centroidCount ) {
47+ System .arraycopy (vector , 0 , centroids [i ], 0 , vector .length );
48+ } else if (random .nextDouble () < centroidCount * (1.0 / i )) {
49+ int c = random .nextInt (centroidCount );
50+ System .arraycopy (vector , 0 , centroids [c ], 0 , vector .length );
8751 }
8852 }
8953 return centroids ;
0 commit comments