diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/ComputeNeighboursBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/ComputeNeighboursBenchmark.java new file mode 100644 index 0000000000000..8d711e8e69da0 --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/ComputeNeighboursBenchmark.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.cluster.NeighborHood; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.SECONDS) +@State(Scope.Benchmark) +// first iteration is complete garbage, so make sure we really warmup +@Warmup(iterations = 1, time = 1) +// real iterations. 
not useful to spend tons of time here, better to fork more +@Measurement(iterations = 3, time = 1) +// engage some noise reduction +@Fork(value = 1) +public class ComputeNeighboursBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Param({ "1000", "2000", "3000", "5000", "10000", "20000", "50000" }) + int numVectors; + + @Param({ "384", "782", "1024" }) + int dims; + + float[][] vectors; + int clusterPerNeighbour = 128; + + @Setup + public void setup() throws IOException { + Random random = new Random(123); + vectors = new float[numVectors][dims]; + for (float[] vector : vectors) { + for (int i = 0; i < dims; i++) { + vector[i] = random.nextFloat(); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void bruteForce(Blackhole bh) { + bh.consume(NeighborHood.computeNeighborhoodsBruteForce(vectors, clusterPerNeighbour)); + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void graph(Blackhole bh) throws IOException { + bh.consume(NeighborHood.computeNeighborhoodsGraph(vectors, clusterPerNeighbour)); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 9e83accef1268..0e5ce3d211737 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -139,40 +139,40 @@ private static int getBestCentroidFromNeighbours( NeighborHood neighborhood, float[] distances ) { - final int limit = neighborhood.neighbors.length - 3; + final int limit = neighborhood.neighbors().length - 3; int bestCentroidOffset = centroidIdx; assert centroidIdx >= 0 && centroidIdx < centroids.length; float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); 
int i = 0; for (; i < limit; i += 4) { - if (minDsq < neighborhood.maxIntraDistance) { + if (minDsq < neighborhood.maxIntraDistance()) { // if the distance found is smaller than the maximum intra-cluster distance // we don't consider it for further re-assignment return bestCentroidOffset; } ESVectorUtil.squareDistanceBulk( vector, - centroids[neighborhood.neighbors[i]], - centroids[neighborhood.neighbors[i + 1]], - centroids[neighborhood.neighbors[i + 2]], - centroids[neighborhood.neighbors[i + 3]], + centroids[neighborhood.neighbors()[i]], + centroids[neighborhood.neighbors()[i + 1]], + centroids[neighborhood.neighbors()[i + 2]], + centroids[neighborhood.neighbors()[i + 3]], distances ); for (int j = 0; j < distances.length; j++) { float dsq = distances[j]; if (dsq < minDsq) { minDsq = dsq; - bestCentroidOffset = neighborhood.neighbors[i + j]; + bestCentroidOffset = neighborhood.neighbors()[i + j]; } } } - for (; i < neighborhood.neighbors.length; i++) { - if (minDsq < neighborhood.maxIntraDistance) { + for (; i < neighborhood.neighbors().length; i++) { + if (minDsq < neighborhood.maxIntraDistance()) { // if the distance found is smaller than the maximum intra-cluster distance // we don't consider it for further re-assignment return bestCentroidOffset; } - int offset = neighborhood.neighbors[i]; + int offset = neighborhood.neighbors()[i]; // float score = neighborhood.scores[i]; assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset; // compute the distance to the centroid @@ -210,52 +210,6 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[] return bestCentroidOffset; } - private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) { - int k = centers.length; - assert k > clustersPerNeighborhood; - NeighborQueue[] neighborQueues = new NeighborQueue[k]; - for (int i = 0; i < k; i++) { - neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); - } - final 
float[] scores = new float[4]; - final int limit = k - 3; - for (int i = 0; i < k - 1; i++) { - float[] center = centers[i]; - int j = i + 1; - for (; j < limit; j += 4) { - ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores); - for (int h = 0; h < 4; h++) { - neighborQueues[j + h].insertWithOverflow(i, scores[h]); - neighborQueues[i].insertWithOverflow(j + h, scores[h]); - } - } - for (; j < k; j++) { - float dsq = VectorUtil.squareDistance(center, centers[j]); - neighborQueues[j].insertWithOverflow(i, dsq); - neighborQueues[i].insertWithOverflow(j, dsq); - } - } - - NeighborHood[] neighborhoods = new NeighborHood[k]; - for (int i = 0; i < k; i++) { - NeighborQueue queue = neighborQueues[i]; - if (queue.size() == 0) { - // no neighbors, skip - neighborhoods[i] = NeighborHood.EMPTY; - continue; - } - // consume the queue into the neighbors array and get the maximum intra-cluster distance - int[] neighbors = new int[queue.size()]; - float maxIntraDistance = queue.topScore(); - int iter = 0; - while (queue.size() > 0) { - neighbors[neighbors.length - ++iter] = queue.pop(); - } - neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); - } - return neighborhoods; - } - private void assignSpilled( FloatVectorValues vectors, KMeansIntermediate kmeansIntermediate, @@ -299,8 +253,8 @@ private void assignSpilled( if (neighborhoods != null) { assert neighborhoods[currAssignment] != null; NeighborHood neighborhood = neighborhoods[currAssignment]; - centroidCount = neighborhood.neighbors.length; - centroidOrds = c -> neighborhood.neighbors[c]; + centroidCount = neighborhood.neighbors().length; + centroidOrds = c -> neighborhood.neighbors()[c]; } else { centroidCount = centroids.length - 1; centroidOrds = c -> c < currAssignment ? 
c : c + 1; // skip the current centroid @@ -344,10 +298,6 @@ private void assignSpilled( } } - record NeighborHood(int[] neighbors, float maxIntraDistance) { - static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); - } - /** * cluster using a lloyd k-means algorithm that is not neighbor aware * @@ -390,7 +340,7 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter NeighborHood[] neighborhoods = null; // if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering if (neighborAware && centroids.length > clustersPerNeighborhood) { - neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood); + neighborhoods = NeighborHood.computeNeighborhoods(centroids, clustersPerNeighborhood); } cluster(vectors, kMeansIntermediate, neighborhoods); if (neighborAware && soarLambda >= 0) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java new file mode 100644 index 0000000000000..5a4397b11c833 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.cluster; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; +import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OnHeapHnswGraph; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +public record NeighborHood(int[] neighbors, float maxIntraDistance) { + + private static final int M = 8; + private static final int EF_CONSTRUCTION = 150; + + static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); + + public static NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException { + assert centers.length > clustersPerNeighborhood; + // experiments shows that below 15k, we better use brute force, otherwise hnsw gives us a nice speed up + if (centers.length < 15_000) { + return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood); + } else { + return computeNeighborhoodsGraph(centers, clustersPerNeighborhood); + } + } + + public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) { + int k = centers.length; + NeighborQueue[] neighborQueues = new NeighborQueue[k]; + for (int i = 0; i < k; i++) { + neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); + } + final float[] scores = new float[4]; + final int limit = k - 3; + for (int i = 0; i < k - 1; i++) { + float[] center = centers[i]; + int j = i + 1; + for (; j < limit; j += 4) { + ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], 
scores); + for (int h = 0; h < 4; h++) { + neighborQueues[j + h].insertWithOverflow(i, scores[h]); + neighborQueues[i].insertWithOverflow(j + h, scores[h]); + } + } + for (; j < k; j++) { + float dsq = VectorUtil.squareDistance(center, centers[j]); + neighborQueues[j].insertWithOverflow(i, dsq); + neighborQueues[i].insertWithOverflow(j, dsq); + } + } + + NeighborHood[] neighborhoods = new NeighborHood[k]; + for (int i = 0; i < k; i++) { + NeighborQueue queue = neighborQueues[i]; + if (queue.size() == 0) { + // no neighbors, skip + neighborhoods[i] = NeighborHood.EMPTY; + continue; + } + // consume the queue into the neighbors array and get the maximum intra-cluster distance + int[] neighbors = new int[queue.size()]; + float maxIntraDistance = queue.topScore(); + int iter = 0; + while (queue.size() > 0) { + neighbors[neighbors.length - ++iter] = queue.pop(); + } + neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); + } + return neighborhoods; + } + + public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException { + final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() { + int scoringOrdinal; + + @Override + public float score(int node) { + return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]); + } + + @Override + public int maxOrd() { + return centers.length; + } + + @Override + public void setScoringOrdinal(int node) { + scoringOrdinal = node; + } + }; + final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() { + @Override + public UpdateableRandomVectorScorer scorer() { + return scorer; + } + + @Override + public RandomVectorScorerSupplier copy() { + return this; + } + }; + final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, M, EF_CONSTRUCTION, 42L).build(centers.length); + final NeighborHood[] neighborhoods = new NeighborHood[centers.length]; + // oversample the number of neighbors we collect to 
improve recall
+        final ReusableKnnCollector collector = new ReusableKnnCollector(2 * clustersPerNeighborhood);
+        for (int i = 0; i < centers.length; i++) {
+            collector.reset(i);
+            scorer.setScoringOrdinal(i);
+            HnswGraphSearcher.search(scorer, collector, graph, null);
+            NeighborQueue queue = collector.queue;
+            if (queue.size() == 0) {
+                // no neighbors, skip
+                neighborhoods[i] = NeighborHood.EMPTY;
+                continue;
+            }
+            while (queue.size() > clustersPerNeighborhood) {
+                queue.pop();
+            }
+            final float minScore = queue.topScore();
+            final int[] neighbors = new int[queue.size()];
+            for (int j = 1; j <= neighbors.length; j++) {
+                neighbors[neighbors.length - j] = queue.pop();
+            }
+            neighborhoods[i] = new NeighborHood(neighbors, (1f / minScore) - 1);
+        }
+        return neighborhoods;
+    }
+
+    private static class ReusableKnnCollector implements KnnCollector {
+
+        private final NeighborQueue queue;
+        private final int k;
+        int visitedCount;
+        int currentOrd;
+
+        ReusableKnnCollector(int k) {
+            this.k = k;
+            this.queue = new NeighborQueue(k, false);
+        }
+
+        void reset(int ord) {
+            queue.clear();
+            visitedCount = 0;
+            currentOrd = ord;
+        }
+
+        @Override
+        public boolean earlyTerminated() {
+            return false;
+        }
+
+        @Override
+        public void incVisitedCount(int count) {
+            visitedCount += count;
+        }
+
+        @Override
+        public long visitedCount() {
+            return visitedCount;
+        }
+
+        @Override
+        public long visitLimit() {
+            return Integer.MAX_VALUE;
+        }
+
+        @Override
+        public int k() {
+            return k;
+        }
+
+        @Override
+        public boolean collect(int docId, float similarity) {
+            if (currentOrd != docId) {
+                return queue.insertWithOverflow(docId, similarity);
+            }
+            return false;
+        }
+
+        @Override
+        public float minCompetitiveSimilarity() {
+            return queue.size() >= k() ? 
queue.topScore() : Float.NEGATIVE_INFINITY; + } + + @Override + public TopDocs topDocs() { + throw new UnsupportedOperationException(); + } + + @Override + public KnnSearchStrategy getSearchStrategy() { + return null; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java index a2d34d28f3784..c69a03ca90bc6 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java @@ -15,9 +15,12 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; public class KMeansLocalTests extends ESTestCase { @@ -141,4 +144,47 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus } return FloatVectorValues.fromFloats(vectors, nDims); } + + public void testComputeNeighbours() throws IOException { + int numCentroids = randomIntBetween(1000, 2000); + int dims = randomIntBetween(10, 200); + float[][] vectors = new float[numCentroids][dims]; + for (int i = 0; i < numCentroids; i++) { + for (int j = 0; j < dims; j++) { + vectors[i][j] = randomFloat(); + } + } + int clustersPerNeighbour = randomIntBetween(32, 128); + NeighborHood[] neighborHoodsGraph = NeighborHood.computeNeighborhoodsGraph(vectors, clustersPerNeighbour); + NeighborHood[] neighborHoodsBruteForce = NeighborHood.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour); + assertEquals(neighborHoodsGraph.length, neighborHoodsBruteForce.length); + for (int i = 0; i < neighborHoodsGraph.length; i++) { + assertEquals(neighborHoodsBruteForce[i].neighbors().length, neighborHoodsGraph[i].neighbors().length); + int matched = compareNN(i, 
neighborHoodsBruteForce[i].neighbors(), neighborHoodsGraph[i].neighbors());
+            double recall = (double) matched / neighborHoodsGraph[i].neighbors().length;
+            assertThat(recall, greaterThanOrEqualTo(0.5));
+            if (recall == 1.0) {
+                // we cannot assert on array equality as there can be small differences due to numerical errors
+                assertEquals(neighborHoodsBruteForce[i].maxIntraDistance(), neighborHoodsGraph[i].maxIntraDistance(), 1e-5f);
+            }
+        }
+    }
+
+    private static int compareNN(int currentId, int[] expected, int[] results) {
+        int matched = 0;
+        Set<Integer> expectedSet = new HashSet<>();
+        Set<Integer> alreadySeen = new HashSet<>();
+        for (int i : expected) {
+            assertNotEquals(currentId, i);
+            assertTrue(expectedSet.add(i));
+        }
+        for (int i : results) {
+            assertNotEquals(currentId, i);
+            assertTrue(alreadySeen.add(i));
+            if (expectedSet.contains(i)) {
+                ++matched;
+            }
+        }
+        return matched;
+    }
 }