Skip to content

Commit 3893098

Browse files
committed
switched to reservoir sampling
1 parent f5f0538 commit 3893098

File tree

3 files changed

+17
-53
lines changed

3 files changed

+17
-53
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeans.java

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
import org.apache.lucene.util.VectorUtil;
1414

1515
import java.io.IOException;
16-
import java.util.ArrayList;
17-
import java.util.Collections;
18-
import java.util.List;
1916
import java.util.Random;
2017

2118
/**
@@ -31,59 +28,26 @@ class KMeans {
3128
this.maxIterations = maxIterations;
3229
}
3330

34-
// FIXME: use me or remove me
35-
private static void shuffle(int[] items, Random random) {
36-
if (items == null || items.length < 2) {
37-
return;
38-
}
39-
40-
for (int i = items.length - 1; i > 0; i--) {
41-
int index = random.nextInt(i + 1);
42-
int temp = items[i];
43-
items[i] = items[index];
44-
items[index] = temp;
45-
}
46-
}
47-
4831
/**
49-
* uses a FORGY approach to picking the initial centroids which are subsequently expected to be used by a clustering algorithm
32+
* uses a Reservoir Sampling approach to picking the initial centroids which are subsequently expected
33+
* to be used by a clustering algorithm
5034
*
5135
* @param vectors used to pick an initial set of random centroids
52-
* @param sampleSize the total number of vectors to be used as part of the sample for centroids
5336
* @param centroidCount the total number of centroids to pick
5437
* @return randomly selected centroids that are the min of centroidCount and sampleSize
5538
* @throws IOException is thrown if vectors is inaccessible
5639
*/
57-
static float[][] pickInitialCentroids(FloatVectorValues vectors, int sampleSize, int centroidCount) throws IOException {
58-
// Choose data points as random ensuring we have distinct points where possible
59-
60-
// FIXME: use me or remove me
61-
// int[] candidates = IntStream.range(0, sampleSize).toArray();
62-
// shuffle(candidates, new Random(42L));
63-
64-
List<Integer> candidates = new ArrayList<>(sampleSize);
65-
for (int i = 0; i < sampleSize; i++) {
66-
candidates.add(i);
67-
}
68-
Collections.shuffle(candidates, new Random(42L));
69-
70-
float[][] centroids = new float[centroidCount][vectors.dimension()];
71-
int centroidIdx = 0;
72-
for (int i = 0; i < candidates.size() && centroidIdx < centroidCount; i++) {
73-
int cand = candidates.get(i);
74-
float[] vector = vectors.vectorValue(cand);
75-
boolean goodCandidate = true;
76-
if (((candidates.size() - i) - (centroidCount - centroidIdx)) > 0) {
77-
for (int j = 0; j < centroidIdx; j++) {
78-
if ((VectorUtil.squareDistance(vector, centroids[j]) > 0.0f) == false) {
79-
goodCandidate = false;
80-
break;
81-
}
82-
}
83-
}
84-
if (goodCandidate) {
85-
System.arraycopy(vector, 0, centroids[centroidIdx], 0, vector.length);
86-
centroidIdx++;
40+
static float[][] pickInitialCentroids(FloatVectorValues vectors, int m, int centroidCount) throws IOException {
41+
Random random = new Random(42L);
42+
int centroidsSize = Math.min(vectors.size(), centroidCount);
43+
float[][] centroids = new float[centroidsSize][vectors.dimension()];
44+
for (int i = 0; i < vectors.size(); i++) {
45+
float[] vector = vectors.vectorValue(i);
46+
if (i < centroidCount) {
47+
System.arraycopy(vector, 0, centroids[i], 0, vector.length);
48+
} else if (random.nextDouble() < centroidCount * (1.0 / i)) {
49+
int c = random.nextInt(centroidCount);
50+
System.arraycopy(vector, 0, centroids[c], 0, vector.length);
8751
}
8852
}
8953
return centroids;

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public void testKMeansLocal() throws IOException {
2828
short clustersPerNeighborhood = (short) random().nextInt(0, 512);
2929
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
3030

31-
float[][] centroids = KMeans.pickInitialCentroids(vectors, sampleSize, nClusters);
31+
float[][] centroids = KMeans.pickInitialCentroids(vectors, nClusters);
3232
KMeans.cluster(vectors, centroids, sampleSize, maxIterations);
3333

3434
int[] assignments = new int[vectors.size()];
@@ -68,7 +68,7 @@ public void testKMeansLocalAllZero() throws IOException {
6868
int sampleSize = vectors.size();
6969
FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, 5);
7070

71-
float[][] centroids = KMeans.pickInitialCentroids(fvv, sampleSize, nClusters);
71+
float[][] centroids = KMeans.pickInitialCentroids(fvv, nClusters);
7272
KMeans.cluster(fvv, centroids, sampleSize, maxIterations);
7373

7474
int[] assignments = new int[vectors.size()];

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public void testKMeans() throws IOException {
2626
int maxIterations = random().nextInt(0, 100);
2727
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
2828

29-
float[][] centroids = KMeans.pickInitialCentroids(vectors, sampleSize, nClusters);
29+
float[][] centroids = KMeans.pickInitialCentroids(vectors, nClusters);
3030
KMeans.cluster(vectors, centroids, sampleSize, maxIterations);
3131

3232
assertEquals(nClusters, centroids.length);
@@ -43,7 +43,7 @@ public void testKMeansAllZero() throws IOException {
4343
}
4444
int sampleSize = vectors.size();
4545
FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, 5);
46-
float[][] centroids = KMeans.pickInitialCentroids(fvv, sampleSize, nClusters);
46+
float[][] centroids = KMeans.pickInitialCentroids(fvv, nClusters);
4747
KMeans.cluster(fvv, centroids, sampleSize, maxIterations);
4848

4949
assertEquals(nClusters, centroids.length);

0 commit comments

Comments
 (0)