Skip to content

Commit 2264a46

Browse files
committed
address review comments
1 parent 651e3bf commit 2264a46

File tree

3 files changed

+129
-76
lines changed

3 files changed

+129
-76
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -210,16 +210,6 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
210210
return bestCentroidOffset;
211211
}
212212

213-
private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
214-
assert centers.length > clustersPerNeighborhood;
215-
// experiments show that below 15k centers brute force is faster; otherwise HNSW gives us a nice speed-up
216-
if (centers.length < 15_000) {
217-
return NeighborHood.computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
218-
} else {
219-
return NeighborHood.computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
220-
}
221-
}
222-
223213
private void assignSpilled(
224214
FloatVectorValues vectors,
225215
KMeansIntermediate kmeansIntermediate,
@@ -350,7 +340,7 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter
350340
NeighborHood[] neighborhoods = null;
351341
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
352342
if (neighborAware && centroids.length > clustersPerNeighborhood) {
353-
neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood);
343+
neighborhoods = NeighborHood.computeNeighborhoods(centroids, clustersPerNeighborhood);
354344
}
355345
cluster(vectors, kMeansIntermediate, neighborhoods);
356346
if (neighborAware && soarLambda >= 0) {

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java

Lines changed: 125 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
import org.apache.lucene.index.VectorSimilarityFunction;
1313
import org.apache.lucene.search.KnnCollector;
14-
import org.apache.lucene.search.ScoreDoc;
15-
import org.apache.lucene.util.Bits;
14+
import org.apache.lucene.search.TopDocs;
15+
import org.apache.lucene.search.knn.KnnSearchStrategy;
1616
import org.apache.lucene.util.VectorUtil;
1717
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
1818
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
@@ -25,8 +25,66 @@
2525

2626
public record NeighborHood(int[] neighbors, float maxIntraDistance) {
2727

28+
private static final int M = 8;
29+
private static final int EF_CONSTRUCTION = 150;
30+
2831
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
2932

33+
public static NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
34+
assert centers.length > clustersPerNeighborhood;
35+
// experiments show that below 15k centers brute force is faster; otherwise HNSW gives us a nice speed-up
36+
if (centers.length < 15_000) {
37+
return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
38+
} else {
39+
return computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
40+
}
41+
}
42+
43+
public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
44+
int k = centers.length;
45+
NeighborQueue[] neighborQueues = new NeighborQueue[k];
46+
for (int i = 0; i < k; i++) {
47+
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
48+
}
49+
final float[] scores = new float[4];
50+
final int limit = k - 3;
51+
for (int i = 0; i < k - 1; i++) {
52+
float[] center = centers[i];
53+
int j = i + 1;
54+
for (; j < limit; j += 4) {
55+
ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores);
56+
for (int h = 0; h < 4; h++) {
57+
neighborQueues[j + h].insertWithOverflow(i, scores[h]);
58+
neighborQueues[i].insertWithOverflow(j + h, scores[h]);
59+
}
60+
}
61+
for (; j < k; j++) {
62+
float dsq = VectorUtil.squareDistance(center, centers[j]);
63+
neighborQueues[j].insertWithOverflow(i, dsq);
64+
neighborQueues[i].insertWithOverflow(j, dsq);
65+
}
66+
}
67+
68+
NeighborHood[] neighborhoods = new NeighborHood[k];
69+
for (int i = 0; i < k; i++) {
70+
NeighborQueue queue = neighborQueues[i];
71+
if (queue.size() == 0) {
72+
// no neighbors, skip
73+
neighborhoods[i] = NeighborHood.EMPTY;
74+
continue;
75+
}
76+
// consume the queue into the neighbors array and get the maximum intra-cluster distance
77+
int[] neighbors = new int[queue.size()];
78+
float maxIntraDistance = queue.topScore();
79+
int iter = 0;
80+
while (queue.size() > 0) {
81+
neighbors[neighbors.length - ++iter] = queue.pop();
82+
}
83+
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
84+
}
85+
return neighborhoods;
86+
}
87+
3088
public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
3189
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
3290
int scoringOrdinal;
@@ -57,92 +115,97 @@ public RandomVectorScorerSupplier copy() {
57115
return this;
58116
}
59117
};
60-
final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length);
118+
final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, M, EF_CONSTRUCTION, 42L).build(centers.length);
61119
final NeighborHood[] neighborhoods = new NeighborHood[centers.length];
62-
final SingleBit singleBit = new SingleBit(centers.length);
120+
// oversample the number of neighbors we collect to improve recall
121+
final ReusableKnnCollector collector = new ReusableKnnCollector(2 * clustersPerNeighborhood);
63122
for (int i = 0; i < centers.length; i++) {
123+
collector.reset(i);
64124
scorer.setScoringOrdinal(i);
65-
singleBit.indexSet = i;
66-
final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE);
67-
final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
68-
if (scoreDocs.length == 0) {
125+
HnswGraphSearcher.search(scorer, collector, graph, null);
126+
NeighborQueue queue = collector.queue;
127+
if (queue.size() == 0) {
69128
// no neighbors, skip
70129
neighborhoods[i] = NeighborHood.EMPTY;
71130
continue;
72131
}
73-
final int[] neighbors = new int[scoreDocs.length];
74-
for (int j = 0; j < neighbors.length; j++) {
75-
neighbors[j] = scoreDocs[j].doc;
76-
assert neighbors[j] != i;
132+
while (queue.size() > clustersPerNeighborhood) {
133+
queue.pop();
134+
}
135+
final float minScore = queue.topScore();
136+
final int[] neighbors = new int[queue.size()];
137+
for (int j = 1; j <= neighbors.length; j++) {
138+
neighbors[neighbors.length - j] = queue.pop();
77139
}
78-
final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1;
79-
neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity);
140+
neighborhoods[i] = new NeighborHood(neighbors, (1f / minScore) - 1);
80141
}
81142
return neighborhoods;
82143
}
83144

84-
private static class SingleBit implements Bits {
145+
private static class ReusableKnnCollector implements KnnCollector {
85146

86-
private final int length;
87-
private int indexSet;
147+
private final NeighborQueue queue;
148+
private final int k;
149+
int visitedCount;
150+
int currenOrd;
88151

89-
SingleBit(int length) {
90-
this.length = length;
152+
ReusableKnnCollector(int k) {
153+
this.k = k;
154+
this.queue = new NeighborQueue(k, false);
155+
}
156+
157+
void reset(int ord) {
158+
queue.clear();
159+
visitedCount = 0;
160+
currenOrd = ord;
91161
}
92162

93163
@Override
94-
public boolean get(int index) {
95-
return index != indexSet;
164+
public boolean earlyTerminated() {
165+
return false;
96166
}
97167

98168
@Override
99-
public int length() {
100-
return length;
169+
public void incVisitedCount(int count) {
170+
visitedCount += count;
101171
}
102-
}
103172

104-
public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
105-
int k = centers.length;
106-
NeighborQueue[] neighborQueues = new NeighborQueue[k];
107-
for (int i = 0; i < k; i++) {
108-
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
173+
@Override
174+
public long visitedCount() {
175+
return visitedCount;
109176
}
110-
final float[] scores = new float[4];
111-
final int limit = k - 3;
112-
for (int i = 0; i < k - 1; i++) {
113-
float[] center = centers[i];
114-
int j = i + 1;
115-
for (; j < limit; j += 4) {
116-
ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores);
117-
for (int h = 0; h < 4; h++) {
118-
neighborQueues[j + h].insertWithOverflow(i, scores[h]);
119-
neighborQueues[i].insertWithOverflow(j + h, scores[h]);
120-
}
121-
}
122-
for (; j < k; j++) {
123-
float dsq = VectorUtil.squareDistance(center, centers[j]);
124-
neighborQueues[j].insertWithOverflow(i, dsq);
125-
neighborQueues[i].insertWithOverflow(j, dsq);
126-
}
177+
178+
@Override
179+
public long visitLimit() {
180+
return Integer.MAX_VALUE;
127181
}
128182

129-
NeighborHood[] neighborhoods = new NeighborHood[k];
130-
for (int i = 0; i < k; i++) {
131-
NeighborQueue queue = neighborQueues[i];
132-
if (queue.size() == 0) {
133-
// no neighbors, skip
134-
neighborhoods[i] = NeighborHood.EMPTY;
135-
continue;
136-
}
137-
// consume the queue into the neighbors array and get the maximum intra-cluster distance
138-
int[] neighbors = new int[queue.size()];
139-
float maxIntraDistance = queue.topScore();
140-
int iter = 0;
141-
while (queue.size() > 0) {
142-
neighbors[neighbors.length - ++iter] = queue.pop();
183+
@Override
184+
public int k() {
185+
return k;
186+
}
187+
188+
@Override
189+
public boolean collect(int docId, float similarity) {
190+
if (currenOrd != docId) {
191+
return queue.insertWithOverflow(docId, similarity);
143192
}
144-
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
193+
return false;
194+
}
195+
196+
@Override
197+
public float minCompetitiveSimilarity() {
198+
return queue.size() >= k() ? queue.topScore() : Float.NEGATIVE_INFINITY;
199+
}
200+
201+
@Override
202+
public TopDocs topDocs() {
203+
throw new UnsupportedOperationException();
204+
}
205+
206+
@Override
207+
public KnnSearchStrategy getSearchStrategy() {
208+
return null;
145209
}
146-
return neighborhoods;
147210
}
148211
}

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,23 +146,23 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
146146
}
147147

148148
public void testComputeNeighbours() throws IOException {
149-
int numCentroids = randomIntBetween(100, 1000);
149+
int numCentroids = randomIntBetween(1000, 2000);
150150
int dims = randomIntBetween(10, 200);
151151
float[][] vectors = new float[numCentroids][dims];
152152
for (int i = 0; i < numCentroids; i++) {
153153
for (int j = 0; j < dims; j++) {
154154
vectors[i][j] = randomFloat();
155155
}
156156
}
157-
int clustersPerNeighbour = randomIntBetween(32, 64);
157+
int clustersPerNeighbour = randomIntBetween(32, 128);
158158
NeighborHood[] neighborHoodsGraph = NeighborHood.computeNeighborhoodsGraph(vectors, clustersPerNeighbour);
159159
NeighborHood[] neighborHoodsBruteForce = NeighborHood.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour);
160160
assertEquals(neighborHoodsGraph.length, neighborHoodsBruteForce.length);
161161
for (int i = 0; i < neighborHoodsGraph.length; i++) {
162162
assertEquals(neighborHoodsBruteForce[i].neighbors().length, neighborHoodsGraph[i].neighbors().length);
163163
int matched = compareNN(i, neighborHoodsBruteForce[i].neighbors(), neighborHoodsGraph[i].neighbors());
164164
double recall = (double) matched / neighborHoodsGraph[i].neighbors().length;
165-
assertThat(recall, greaterThanOrEqualTo(0.7));
165+
assertThat(recall, greaterThanOrEqualTo(0.5));
166166
if (recall == 1.0) {
167167
// we cannot assert on array equality as there can be small differences due to numerical errors
168168
assertEquals(neighborHoodsBruteForce[i].maxIntraDistance(), neighborHoodsGraph[i].maxIntraDistance(), 1e-5f);

0 commit comments

Comments
 (0)