Skip to content

Commit ab150c8

Browse files
committed
[DiskBBQ] Use a hnsw graph to compute neighbours
1 parent 3eefef7 commit ab150c8

File tree

2 files changed

+140
-2
lines changed

2 files changed

+140
-2
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,18 @@
1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

1212
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.index.VectorSimilarityFunction;
14+
import org.apache.lucene.search.KnnCollector;
15+
import org.apache.lucene.search.ScoreDoc;
16+
import org.apache.lucene.util.Bits;
1317
import org.apache.lucene.util.FixedBitSet;
1418
import org.apache.lucene.util.VectorUtil;
19+
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
20+
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
1521
import org.apache.lucene.util.hnsw.IntToIntFunction;
22+
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
23+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
24+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
1625
import org.elasticsearch.index.codec.vectors.SampleReader;
1726
import org.elasticsearch.simdvec.ESVectorUtil;
1827

@@ -210,9 +219,92 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
210219
return bestCentroidOffset;
211220
}
212221

213-
private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) {
222+
private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
223+
assert centers.length > clustersPerNeighborhood;
224+
// experiments shows that below 20k, we better use brute force, otherwise hnsw gives us a nice speed up
225+
if (centers.length < 20_000) {
226+
return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
227+
} else {
228+
return computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
229+
}
230+
}
231+
232+
static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
233+
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
234+
int scoringOrdinal;
235+
236+
@Override
237+
public float score(int node) {
238+
return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]);
239+
}
240+
241+
@Override
242+
public int maxOrd() {
243+
return centers.length;
244+
}
245+
246+
@Override
247+
public void setScoringOrdinal(int node) {
248+
scoringOrdinal = node;
249+
}
250+
};
251+
final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() {
252+
@Override
253+
public UpdateableRandomVectorScorer scorer() {
254+
return scorer;
255+
}
256+
257+
@Override
258+
public RandomVectorScorerSupplier copy() {
259+
return this;
260+
}
261+
};
262+
final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length);
263+
final NeighborHood[] neighborhoods = new NeighborHood[centers.length];
264+
final SingleBit singleBit = new SingleBit(centers.length);
265+
for (int i = 0; i < centers.length; i++) {
266+
scorer.setScoringOrdinal(i);
267+
singleBit.indexSet = i;
268+
final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE);
269+
final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
270+
if (scoreDocs.length == 0) {
271+
// no neighbors, skip
272+
neighborhoods[i] = NeighborHood.EMPTY;
273+
continue;
274+
}
275+
final int[] neighbors = new int[scoreDocs.length];
276+
for (int j = 0; j < neighbors.length; j++) {
277+
neighbors[j] = scoreDocs[j].doc;
278+
assert neighbors[j] != i;
279+
}
280+
final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1;
281+
neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity);
282+
}
283+
return neighborhoods;
284+
}
285+
286+
private static class SingleBit implements Bits {
287+
288+
private final int length;
289+
private int indexSet;
290+
291+
SingleBit(int length) {
292+
this.length = length;
293+
}
294+
295+
@Override
296+
public boolean get(int index) {
297+
return index != indexSet;
298+
}
299+
300+
@Override
301+
public int length() {
302+
return length;
303+
}
304+
}
305+
306+
static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
214307
int k = centers.length;
215-
assert k > clustersPerNeighborhood;
216308
NeighborQueue[] neighborQueues = new NeighborQueue[k];
217309
for (int i = 0; i < k; i++) {
218310
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
import java.io.IOException;
1717
import java.util.ArrayList;
18+
import java.util.HashSet;
1819
import java.util.List;
20+
import java.util.Set;
1921

2022
import static org.hamcrest.Matchers.containsString;
23+
import static org.hamcrest.Matchers.greaterThan;
2124

2225
public class KMeansLocalTests extends ESTestCase {
2326

@@ -141,4 +144,47 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
141144
}
142145
return FloatVectorValues.fromFloats(vectors, nDims);
143146
}
147+
148+
public void testComputeNeighbours() throws IOException {
149+
int numCentroids = randomIntBetween(100, 10000);
150+
int dims = randomIntBetween(10, 200);
151+
float[][] vectors = new float[numCentroids][dims];
152+
for (int i = 0; i < numCentroids; i++) {
153+
for (int j = 0; j < dims; j++) {
154+
vectors[i][j] = randomFloat();
155+
}
156+
}
157+
int clustersPerNeighbour = randomIntBetween(6, 32);
158+
KMeansLocal.NeighborHood[] neighborHoodsGraph = KMeansLocal.computeNeighborhoodsGraph(vectors, clustersPerNeighbour);
159+
KMeansLocal.NeighborHood[] neighborHoodsBruteForce = KMeansLocal.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour);
160+
assertEquals(neighborHoodsGraph.length, neighborHoodsBruteForce.length);
161+
for (int i = 0; i < neighborHoodsGraph.length; i++) {
162+
assertEquals(neighborHoodsBruteForce[i].neighbors().length, neighborHoodsGraph[i].neighbors().length);
163+
int matched = compareNN(i, neighborHoodsBruteForce[i].neighbors(), neighborHoodsGraph[i].neighbors());
164+
double recall = (double) matched / neighborHoodsGraph[i].neighbors().length;
165+
assertThat(recall, greaterThan(0.6));
166+
if (recall == 1.0) {
167+
// we cannot assert on array equality as there can be small differences due to numerical errors
168+
assertEquals(neighborHoodsBruteForce[i].maxIntraDistance(), neighborHoodsGraph[i].maxIntraDistance(), 1e-5f);
169+
}
170+
}
171+
}
172+
173+
private static int compareNN(int currentId, int[] expected, int[] results) {
174+
int matched = 0;
175+
Set<Integer> expectedSet = new HashSet<>();
176+
Set<Integer> alreadySeen = new HashSet<>();
177+
for (int i : expected) {
178+
assertNotEquals(currentId, i);
179+
assertTrue(expectedSet.add(i));
180+
}
181+
for (int i : results) {
182+
assertNotEquals(currentId, i);
183+
assertTrue(alreadySeen.add(i));
184+
if (expectedSet.contains(i)) {
185+
++matched;
186+
}
187+
}
188+
return matched;
189+
}
144190
}

0 commit comments

Comments
 (0)