Merged

@@ -10,9 +10,18 @@
package org.elasticsearch.index.codec.vectors.cluster;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.SampleReader;
import org.elasticsearch.simdvec.ESVectorUtil;

@@ -210,9 +219,92 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
return bestCentroidOffset;
}

-    private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) {
+    private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
assert centers.length > clustersPerNeighborhood;
        // experiments show that below 15k centroids we are better off with brute force; beyond that, HNSW gives us a nice speedup
if (centers.length < 15_000) {
Contributor: I think we can optimise the graph to work better for lower scale, but this is good as a first threshold. That's for segments greater than 1M with 64 vectors per centroid.

Member: Reducing the number of connections could make this threshold smaller.

Contributor Author (@iverase), Sep 4, 2025: Agree. I didn't spend too much time on it because it seems pretty fast for low values (a few seconds), so I wonder if there is a need to optimize those cases.

Member: I think just picking something "good enough" is alright. It provides a nice improvement, and any optimizations we make won't be "format breaking" :)

return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
} else {
return computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
}
}

static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
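        // Scores centroids against each other with Euclidean similarity; the scoring
        // ordinal tracks the centroid currently being queried.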
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
int scoringOrdinal;

@Override
public float score(int node) {
return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]);
}

@Override
public int maxOrd() {
return centers.length;
}

@Override
public void setScoringOrdinal(int node) {
scoringOrdinal = node;
}
};
final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() {
@Override
public UpdateableRandomVectorScorer scorer() {
return scorer;
}

@Override
public RandomVectorScorerSupplier copy() {
return this;
}
};
final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length);
Contributor: I think it's worth spending a bit more time on optimising this. In my testing, M=8 had the best ratio of recall to visited percentage, so it might be beneficial to publish your macro benchmark.

Contributor Author: I refactored the code in 5034bca to publish the benchmark.

Member: I do think that 8 with a larger beamwidth is worth trying. (8, 150)
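
A minimal sketch of that (8, 150) variant, assuming the same supplier as in the diff; the parameters are the reviewers' suggestion, not what this change ships:

// Hypothetical build call with fewer connections per node (M = 8) but a wider
// construction beam (beamWidth = 150), keeping the same seed for reproducibility.
final OnHeapHnswGraph tunedGraph = HnswGraphBuilder.create(supplier, 8, 150, 42L).build(centers.length);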

final NeighborHood[] neighborhoods = new NeighborHood[centers.length];
final SingleBit singleBit = new SingleBit(centers.length);
for (int i = 0; i < centers.length; i++) {
scorer.setScoringOrdinal(i);
singleBit.indexSet = i;
final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE);
Contributor: Did you test multiple sizes? I guess that recall is important here, so we should aim for a recall of 1?

Contributor Author: We always use 128 for clustersPerNeighborhood. While ideally recall should be close to 1, the test does not show a loss of quality in the centroids.

Contributor Author: Do you mean we should oversample here, for example use 2 * clustersPerNeighborhood to make sure we always get the top clustersPerNeighborhood?

Member: > Do you mean we should oversample here, for example use 2 * clustersPerNeighborhood to make sure we always get the top clustersPerNeighborhood?

Generally, your approximate measure for HNSW is efSearch, which in this case would be an oversample. I am not sure 2x is required, but possibly more than just the number of nearest neighbors we care about.

Contributor: Do we really need 128 no matter what number of centroids we have? Reducing this value when we have a small number of centroids could make the graph strategy applicable earlier.

Member: I do think it should be scaled down to a lower value when there are fewer centroids. I do not know what that value would be. The number is coupled to the recursive cluster splits to help capture potentially mis-assigned vectors along the split edges.
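
A minimal sketch of that efSearch-style oversample, reusing the scorer, graph, and singleBit filter set up above; the 2x factor and the truncation are illustrative assumptions, not part of this change:

// Ask the graph for more candidates than we intend to keep, then truncate to the
// top clustersPerNeighborhood hits.
final int efSearch = 2 * clustersPerNeighborhood;
final KnnCollector oversampled = HnswGraphSearcher.search(scorer, efSearch, graph, singleBit, Integer.MAX_VALUE);
final ScoreDoc[] top = oversampled.topDocs().scoreDocs;
final int keep = Math.min(clustersPerNeighborhood, top.length);
final int[] nearest = new int[keep];
for (int j = 0; j < keep; j++) {
    nearest[j] = top[j].doc;
}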

final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
if (scoreDocs.length == 0) {
// no neighbors, skip
neighborhoods[i] = NeighborHood.EMPTY;
continue;
}
final int[] neighbors = new int[scoreDocs.length];
for (int j = 0; j < neighbors.length; j++) {
neighbors[j] = scoreDocs[j].doc;
assert neighbors[j] != i;
}
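            // Lucene's EUCLIDEAN score is 1 / (1 + squaredDistance), so 1 / score - 1
            // recovers the squared distance to the furthest accepted neighbor.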
final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1;
neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity);
}
return neighborhoods;
}

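    /**
     * A Bits view over all ordinals that rejects exactly one index, used as the
     * accept-ords filter so a centroid is never returned as its own neighbor.
     */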
private static class SingleBit implements Bits {

private final int length;
private int indexSet;

SingleBit(int length) {
this.length = length;
}

@Override
public boolean get(int index) {
return index != indexSet;
}

@Override
public int length() {
return length;
}
}

static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
int k = centers.length;
assert k > clustersPerNeighborhood;
NeighborQueue[] neighborQueues = new NeighborQueue[k];
for (int i = 0; i < k; i++) {
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
@@ -15,9 +15,12 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThan;

public class KMeansLocalTests extends ESTestCase {

@@ -141,4 +144,47 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
}
return FloatVectorValues.fromFloats(vectors, nDims);
}

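    // Verifies that the HNSW-based neighborhoods approximate the brute-force ones
    // with acceptable recall, and that distances agree when recall is perfect.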
public void testComputeNeighbours() throws IOException {
int numCentroids = randomIntBetween(100, 10000);
int dims = randomIntBetween(10, 200);
float[][] vectors = new float[numCentroids][dims];
for (int i = 0; i < numCentroids; i++) {
for (int j = 0; j < dims; j++) {
vectors[i][j] = randomFloat();
}
}
int clustersPerNeighbour = randomIntBetween(6, 32);
KMeansLocal.NeighborHood[] neighborHoodsGraph = KMeansLocal.computeNeighborhoodsGraph(vectors, clustersPerNeighbour);
KMeansLocal.NeighborHood[] neighborHoodsBruteForce = KMeansLocal.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour);
assertEquals(neighborHoodsGraph.length, neighborHoodsBruteForce.length);
for (int i = 0; i < neighborHoodsGraph.length; i++) {
assertEquals(neighborHoodsBruteForce[i].neighbors().length, neighborHoodsGraph[i].neighbors().length);
int matched = compareNN(i, neighborHoodsBruteForce[i].neighbors(), neighborHoodsGraph[i].neighbors());
double recall = (double) matched / neighborHoodsGraph[i].neighbors().length;
assertThat(recall, greaterThan(0.4));
if (recall == 1.0) {
// we cannot assert on array equality as there can be small differences due to numerical errors
assertEquals(neighborHoodsBruteForce[i].maxIntraDistance(), neighborHoodsGraph[i].maxIntraDistance(), 1e-5f);
}
}
}

private static int compareNN(int currentId, int[] expected, int[] results) {
int matched = 0;
Set<Integer> expectedSet = new HashSet<>();
Set<Integer> alreadySeen = new HashSet<>();
for (int i : expected) {
assertNotEquals(currentId, i);
assertTrue(expectedSet.add(i));
}
for (int i : results) {
assertNotEquals(currentId, i);
assertTrue(alreadySeen.add(i));
if (expectedSet.contains(i)) {
++matched;
}
}
return matched;
}
}