[DiskBBQ] Use a HNSW graph to compute neighbours #134109
@@ -10,9 +10,18 @@

```java
package org.elasticsearch.index.codec.vectors.cluster;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.SampleReader;
import org.elasticsearch.simdvec.ESVectorUtil;
```

@@ -210,9 +219,92 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]

```java
        return bestCentroidOffset;
    }

    private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
        assert centers.length > clustersPerNeighborhood;
        // experiments show that below 15k centroids brute force is faster; beyond that, HNSW gives a nice speed-up
        if (centers.length < 15_000) {
            return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
        } else {
            return computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
        }
    }

    static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
        // Scorer over the centroid array; the scoring ordinal selects which centroid acts as the query.
        final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
            int scoringOrdinal;

            @Override
            public float score(int node) {
                return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]);
            }

            @Override
            public int maxOrd() {
                return centers.length;
            }

            @Override
            public void setScoringOrdinal(int node) {
                scoringOrdinal = node;
            }
        };
        final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() {
            @Override
            public UpdateableRandomVectorScorer scorer() {
                return scorer;
            }

            @Override
            public RandomVectorScorerSupplier copy() {
                return this;
            }
        };
        final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length);

        final NeighborHood[] neighborhoods = new NeighborHood[centers.length];
        final SingleBit singleBit = new SingleBit(centers.length);
        for (int i = 0; i < centers.length; i++) {
            scorer.setScoringOrdinal(i);
            singleBit.indexSet = i;
            final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE);

            final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
            if (scoreDocs.length == 0) {
                // no neighbors, skip
                neighborhoods[i] = NeighborHood.EMPTY;
                continue;
            }
            final int[] neighbors = new int[scoreDocs.length];
            for (int j = 0; j < neighbors.length; j++) {
                neighbors[j] = scoreDocs[j].doc;
                assert neighbors[j] != i;
            }
            // invert Lucene's EUCLIDEAN similarity (1 / (1 + d^2)) back into a squared distance
            final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1;
            neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity);
        }
        return neighborhoods;
    }

    // Bits view that accepts every ordinal except indexSet, so a centroid is never returned as its own neighbor.
    private static class SingleBit implements Bits {

        private final int length;
        private int indexSet;

        SingleBit(int length) {
            this.length = length;
        }

        @Override
        public boolean get(int index) {
            return index != indexSet;
        }

        @Override
        public int length() {
            return length;
        }
    }

    static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
        int k = centers.length;
        assert k > clustersPerNeighborhood;
        NeighborQueue[] neighborQueues = new NeighborQueue[k];
        for (int i = 0; i < k; i++) {
            neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
```
I think we can optimise the graph to work better for lower scale, but this is good as a first threshold. It corresponds to segments of more than 1M vectors at 64 vectors per centroid (15,000 centroids × 64 ≈ 1M vectors).
Reducing the number of connections could make this threshold smaller.
Agree. I didn't spend too much time on it because it seems pretty fast for low values (a few seconds), so I wonder if there is a need to optimize those cases.
I think just picking something "good enough" is alright. It provides a nice improvement and any optimizations we make won't be "format breaking" :)
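The "reducing the number of connections" idea discussed above would only touch the second argument of `HnswGraphBuilder.create` (the diff uses 16 connections and a beam width of 100). A sketch of what that tuning could look like, reusing the PR's scorer supplier; the value of `m` here is hypothetical and untuned:

```java
import java.io.IOException;

import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

class SmallerGraphSketch {
    // supplier is the RandomVectorScorerSupplier over the centroids, as built in the PR.
    static OnHeapHnswGraph build(RandomVectorScorerSupplier supplier, int numCenters) throws IOException {
        int m = 8;           // hypothetical: half the PR's 16 connections per node, cheaper to build
        int beamWidth = 100; // unchanged from the PR
        return HnswGraphBuilder.create(supplier, m, beamWidth, 42L).build(numCenters);
    }
}
```

Fewer connections per node lowers the per-centroid build cost, which is what could shift the 15k brute-force/graph crossover downward.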