Skip to content

Commit 5034bca

Browse files
committed
add benchmark
1 parent 53f4bd6 commit 5034bca

File tree

3 files changed

+239
-146
lines changed

3 files changed

+239
-146
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.benchmark.vector;

import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.cluster.NeighborHood;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.util.Random;
import java.util.concurrent.TimeUnit;

/**
 * Benchmark comparing the two strategies for computing centroid neighborhoods
 * ({@link NeighborHood#computeNeighborhoodsBruteForce} vs.
 * {@link NeighborHood#computeNeighborhoodsGraph}) across a range of vector
 * counts and dimensions, to validate the brute-force/HNSW crossover point
 * used by KMeansLocal.
 */
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warmup
@Warmup(iterations = 1, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 3, time = 1)
// engage some noise reduction
@Fork(value = 1)
public class ComputeNeighboursBenchmark {

    static {
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    @Param({ "1000", "2000", "3000", "5000", "10000", "20000", "50000" })
    int numVectors;

    // common embedding widths; 768 (not 782) is the standard transformer dimension
    // alongside 384 (MiniLM-class) and 1024 (large models)
    @Param({ "384", "768", "1024" })
    int dims;

    // randomly generated input vectors, (re)built per trial in setup()
    float[][] vectors;
    // how many neighbouring clusters to track per centroid; matches the value
    // KMeansLocal passes as clustersPerNeighborhood
    int clusterPerNeighbour = 128;

    /** Fills {@link #vectors} with deterministic pseudo-random data (fixed seed). */
    @Setup
    public void setup() {
        // fixed seed so every fork and parameter combination sees identical data
        Random random = new Random(123);
        vectors = new float[numVectors][dims];
        for (float[] vector : vectors) {
            for (int i = 0; i < dims; i++) {
                vector[i] = random.nextFloat();
            }
        }
    }

    /** Measures the O(n^2) pairwise-distance implementation. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void bruteForce(Blackhole bh) {
        bh.consume(NeighborHood.computeNeighborhoodsBruteForce(vectors, clusterPerNeighbour));
    }

    /** Measures the HNSW-graph based implementation (may throw from graph construction). */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void graph(Blackhole bh) throws java.io.IOException {
        bh.consume(NeighborHood.computeNeighborhoodsGraph(vectors, clusterPerNeighbour));
    }
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 14 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,9 @@
1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

1212
import org.apache.lucene.index.FloatVectorValues;
13-
import org.apache.lucene.index.VectorSimilarityFunction;
14-
import org.apache.lucene.search.KnnCollector;
15-
import org.apache.lucene.search.ScoreDoc;
16-
import org.apache.lucene.util.Bits;
1713
import org.apache.lucene.util.FixedBitSet;
1814
import org.apache.lucene.util.VectorUtil;
19-
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
20-
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
2115
import org.apache.lucene.util.hnsw.IntToIntFunction;
22-
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
23-
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
24-
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
2516
import org.elasticsearch.index.codec.vectors.SampleReader;
2617
import org.elasticsearch.simdvec.ESVectorUtil;
2718

@@ -148,40 +139,40 @@ private static int getBestCentroidFromNeighbours(
148139
NeighborHood neighborhood,
149140
float[] distances
150141
) {
151-
final int limit = neighborhood.neighbors.length - 3;
142+
final int limit = neighborhood.neighbors().length - 3;
152143
int bestCentroidOffset = centroidIdx;
153144
assert centroidIdx >= 0 && centroidIdx < centroids.length;
154145
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
155146
int i = 0;
156147
for (; i < limit; i += 4) {
157-
if (minDsq < neighborhood.maxIntraDistance) {
148+
if (minDsq < neighborhood.maxIntraDistance()) {
158149
// if the distance found is smaller than the maximum intra-cluster distance
159150
// we don't consider it for further re-assignment
160151
return bestCentroidOffset;
161152
}
162153
ESVectorUtil.squareDistanceBulk(
163154
vector,
164-
centroids[neighborhood.neighbors[i]],
165-
centroids[neighborhood.neighbors[i + 1]],
166-
centroids[neighborhood.neighbors[i + 2]],
167-
centroids[neighborhood.neighbors[i + 3]],
155+
centroids[neighborhood.neighbors()[i]],
156+
centroids[neighborhood.neighbors()[i + 1]],
157+
centroids[neighborhood.neighbors()[i + 2]],
158+
centroids[neighborhood.neighbors()[i + 3]],
168159
distances
169160
);
170161
for (int j = 0; j < distances.length; j++) {
171162
float dsq = distances[j];
172163
if (dsq < minDsq) {
173164
minDsq = dsq;
174-
bestCentroidOffset = neighborhood.neighbors[i + j];
165+
bestCentroidOffset = neighborhood.neighbors()[i + j];
175166
}
176167
}
177168
}
178-
for (; i < neighborhood.neighbors.length; i++) {
179-
if (minDsq < neighborhood.maxIntraDistance) {
169+
for (; i < neighborhood.neighbors().length; i++) {
170+
if (minDsq < neighborhood.maxIntraDistance()) {
180171
// if the distance found is smaller than the maximum intra-cluster distance
181172
// we don't consider it for further re-assignment
182173
return bestCentroidOffset;
183174
}
184-
int offset = neighborhood.neighbors[i];
175+
int offset = neighborhood.neighbors()[i];
185176
// float score = neighborhood.scores[i];
186177
assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset;
187178
// compute the distance to the centroid
@@ -223,131 +214,12 @@ private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNe
223214
assert centers.length > clustersPerNeighborhood;
224215
// experiments shows that below 15k, we better use brute force, otherwise hnsw gives us a nice speed up
225216
if (centers.length < 15_000) {
226-
return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
217+
return NeighborHood.computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
227218
} else {
228-
return computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
219+
return NeighborHood.computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
229220
}
230221
}
231222

232-
static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
233-
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
234-
int scoringOrdinal;
235-
236-
@Override
237-
public float score(int node) {
238-
return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]);
239-
}
240-
241-
@Override
242-
public int maxOrd() {
243-
return centers.length;
244-
}
245-
246-
@Override
247-
public void setScoringOrdinal(int node) {
248-
scoringOrdinal = node;
249-
}
250-
};
251-
final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() {
252-
@Override
253-
public UpdateableRandomVectorScorer scorer() {
254-
return scorer;
255-
}
256-
257-
@Override
258-
public RandomVectorScorerSupplier copy() {
259-
return this;
260-
}
261-
};
262-
final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length);
263-
final NeighborHood[] neighborhoods = new NeighborHood[centers.length];
264-
final SingleBit singleBit = new SingleBit(centers.length);
265-
for (int i = 0; i < centers.length; i++) {
266-
scorer.setScoringOrdinal(i);
267-
singleBit.indexSet = i;
268-
final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE);
269-
final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
270-
if (scoreDocs.length == 0) {
271-
// no neighbors, skip
272-
neighborhoods[i] = NeighborHood.EMPTY;
273-
continue;
274-
}
275-
final int[] neighbors = new int[scoreDocs.length];
276-
for (int j = 0; j < neighbors.length; j++) {
277-
neighbors[j] = scoreDocs[j].doc;
278-
assert neighbors[j] != i;
279-
}
280-
final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1;
281-
neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity);
282-
}
283-
return neighborhoods;
284-
}
285-
286-
private static class SingleBit implements Bits {
287-
288-
private final int length;
289-
private int indexSet;
290-
291-
SingleBit(int length) {
292-
this.length = length;
293-
}
294-
295-
@Override
296-
public boolean get(int index) {
297-
return index != indexSet;
298-
}
299-
300-
@Override
301-
public int length() {
302-
return length;
303-
}
304-
}
305-
306-
static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
307-
int k = centers.length;
308-
NeighborQueue[] neighborQueues = new NeighborQueue[k];
309-
for (int i = 0; i < k; i++) {
310-
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
311-
}
312-
final float[] scores = new float[4];
313-
final int limit = k - 3;
314-
for (int i = 0; i < k - 1; i++) {
315-
float[] center = centers[i];
316-
int j = i + 1;
317-
for (; j < limit; j += 4) {
318-
ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores);
319-
for (int h = 0; h < 4; h++) {
320-
neighborQueues[j + h].insertWithOverflow(i, scores[h]);
321-
neighborQueues[i].insertWithOverflow(j + h, scores[h]);
322-
}
323-
}
324-
for (; j < k; j++) {
325-
float dsq = VectorUtil.squareDistance(center, centers[j]);
326-
neighborQueues[j].insertWithOverflow(i, dsq);
327-
neighborQueues[i].insertWithOverflow(j, dsq);
328-
}
329-
}
330-
331-
NeighborHood[] neighborhoods = new NeighborHood[k];
332-
for (int i = 0; i < k; i++) {
333-
NeighborQueue queue = neighborQueues[i];
334-
if (queue.size() == 0) {
335-
// no neighbors, skip
336-
neighborhoods[i] = NeighborHood.EMPTY;
337-
continue;
338-
}
339-
// consume the queue into the neighbors array and get the maximum intra-cluster distance
340-
int[] neighbors = new int[queue.size()];
341-
float maxIntraDistance = queue.topScore();
342-
int iter = 0;
343-
while (queue.size() > 0) {
344-
neighbors[neighbors.length - ++iter] = queue.pop();
345-
}
346-
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
347-
}
348-
return neighborhoods;
349-
}
350-
351223
private void assignSpilled(
352224
FloatVectorValues vectors,
353225
KMeansIntermediate kmeansIntermediate,
@@ -391,8 +263,8 @@ private void assignSpilled(
391263
if (neighborhoods != null) {
392264
assert neighborhoods[currAssignment] != null;
393265
NeighborHood neighborhood = neighborhoods[currAssignment];
394-
centroidCount = neighborhood.neighbors.length;
395-
centroidOrds = c -> neighborhood.neighbors[c];
266+
centroidCount = neighborhood.neighbors().length;
267+
centroidOrds = c -> neighborhood.neighbors()[c];
396268
} else {
397269
centroidCount = centroids.length - 1;
398270
centroidOrds = c -> c < currAssignment ? c : c + 1; // skip the current centroid
@@ -436,10 +308,6 @@ private void assignSpilled(
436308
}
437309
}
438310

439-
record NeighborHood(int[] neighbors, float maxIntraDistance) {
440-
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
441-
}
442-
443311
/**
444312
* cluster using a lloyd k-means algorithm that is not neighbor aware
445313
*

0 commit comments

Comments
 (0)