[DiskBBQ] Use a HNSW graph to compute neighbours #134109
     Merged
      
      
    
  
ab150c8  [DiskBBQ] Use a hnsw graph to compute neighbours (iverase)
31bac6b  iter (iverase)
53f4bd6  iter (iverase)
5034bca  add benchmark (iverase)
b153b8e  Merge branch 'main' into computeNeighbours (iverase)
651e3bf  fix test (iverase)
2264a46  address review comments (iverase)
d4eea58  Merge branch 'main' into computeNeighbours (iverase)
b9cdea0  Merge branch 'main' into computeNeighbours (iverase)
            77 changes: 77 additions & 0 deletions
          
  benchmarks/src/main/java/org/elasticsearch/benchmark/vector/ComputeNeighboursBenchmark.java
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the "Elastic License | ||
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||
| * Public License v 1"; you may not use this file except in compliance with, at | ||
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | ||
| * License v3.0 only", or the "Server Side Public License, v 1". | ||
| */ | ||
|  | ||
| package org.elasticsearch.benchmark.vector; | ||
|  | ||
| import org.elasticsearch.common.logging.LogConfigurator; | ||
| import org.elasticsearch.index.codec.vectors.cluster.NeighborHood; | ||
| import org.openjdk.jmh.annotations.Benchmark; | ||
| import org.openjdk.jmh.annotations.BenchmarkMode; | ||
| import org.openjdk.jmh.annotations.Fork; | ||
| import org.openjdk.jmh.annotations.Measurement; | ||
| import org.openjdk.jmh.annotations.Mode; | ||
| import org.openjdk.jmh.annotations.OutputTimeUnit; | ||
| import org.openjdk.jmh.annotations.Param; | ||
| import org.openjdk.jmh.annotations.Scope; | ||
| import org.openjdk.jmh.annotations.Setup; | ||
| import org.openjdk.jmh.annotations.State; | ||
| import org.openjdk.jmh.annotations.Warmup; | ||
| import org.openjdk.jmh.infra.Blackhole; | ||
|  | ||
| import java.io.IOException; | ||
| import java.util.Random; | ||
| import java.util.concurrent.TimeUnit; | ||
|  | ||
| @BenchmarkMode(Mode.AverageTime) | ||
| @OutputTimeUnit(TimeUnit.SECONDS) | ||
| @State(Scope.Benchmark) | ||
| // first iteration is complete garbage, so make sure we really warmup | ||
| @Warmup(iterations = 1, time = 1) | ||
| // real iterations. not useful to spend tons of time here, better to fork more | ||
| @Measurement(iterations = 3, time = 1) | ||
| // engage some noise reduction | ||
| @Fork(value = 1) | ||
| public class ComputeNeighboursBenchmark { | ||
|  | ||
| static { | ||
| LogConfigurator.configureESLogging(); // native access requires logging to be initialized | ||
| } | ||
|  | ||
| @Param({ "1000", "2000", "3000", "5000", "10000", "20000", "50000" }) | ||
| int numVectors; | ||
|  | ||
| @Param({ "384", "782", "1024" }) | ||
| int dims; | ||
|  | ||
| float[][] vectors; | ||
| int clusterPerNeighbour = 128; | ||
|  | ||
| @Setup | ||
| public void setup() throws IOException { | ||
| Random random = new Random(123); | ||
| vectors = new float[numVectors][dims]; | ||
| for (float[] vector : vectors) { | ||
| for (int i = 0; i < dims; i++) { | ||
| vector[i] = random.nextFloat(); | ||
| } | ||
| } | ||
| } | ||
|  | ||
| @Benchmark | ||
| @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) | ||
| public void bruteForce(Blackhole bh) { | ||
| bh.consume(NeighborHood.computeNeighborhoodsBruteForce(vectors, clusterPerNeighbour)); | ||
| } | ||
|  | ||
| @Benchmark | ||
| @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) | ||
| public void graph(Blackhole bh) throws IOException { | ||
| bh.consume(NeighborHood.computeNeighborhoodsGraph(vectors, clusterPerNeighbour)); | ||
| } | ||
| } | 
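As a rough, dependency-free illustration of what the brute-force path being benchmarked does, the sketch below scores every pair of centers once and offers the result to both centers' bounded heaps. It is not the PR's code: the class name `BruteForceNeighboursSketch`, the `Candidate` helper, and the use of `java.util.PriorityQueue` in place of Elasticsearch's `NeighborQueue` and the bulk SIMD distance call are all assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

class BruteForceNeighboursSketch {

    // Hypothetical helper (not in the PR): one kept neighbour and its squared distance.
    private static final class Candidate {
        final int index;
        final float distance;

        Candidate(int index, float distance) {
            this.index = index;
            this.distance = distance;
        }
    }

    /** Squared Euclidean distance between two vectors of equal length. */
    static float squareDistance(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Keeps at most k candidates in the heap, evicting the farthest one when a closer one arrives. */
    private static void offer(PriorityQueue<Candidate> heap, int k, int index, float distance) {
        if (k <= 0) {
            return;
        }
        if (heap.size() < k) {
            heap.add(new Candidate(index, distance));
        } else if (distance < heap.peek().distance) {
            heap.poll();
            heap.add(new Candidate(index, distance));
        }
    }

    /** For every center, returns the indices of its k nearest other centers, nearest first. */
    static int[][] nearestNeighbours(float[][] centers, int k) {
        int n = centers.length;
        // Max-heap on distance: the head is always the farthest neighbour kept so far.
        Comparator<Candidate> farthestFirst = (a, b) -> Float.compare(b.distance, a.distance);
        List<PriorityQueue<Candidate>> heaps = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            heaps.add(new PriorityQueue<>(farthestFirst));
        }
        // Distance is symmetric, so each pair (i, j) is scored once and offered to both heaps.
        for (int i = 0; i < n - 1; i++) {
            for (int j = i + 1; j < n; j++) {
                float d = squareDistance(centers[i], centers[j]);
                offer(heaps.get(i), k, j, d);
                offer(heaps.get(j), k, i, d);
            }
        }
        int[][] result = new int[n][];
        for (int i = 0; i < n; i++) {
            PriorityQueue<Candidate> heap = heaps.get(i);
            int[] neighbours = new int[heap.size()];
            // Polling yields farthest first, so fill the array from the back.
            for (int pos = neighbours.length - 1; pos >= 0; pos--) {
                neighbours[pos] = heap.poll().index;
            }
            result[i] = neighbours;
        }
        return result;
    }

    public static void main(String[] args) {
        float[][] centers = { { 0f }, { 1f }, { 3f }, { 10f } };
        int[][] neighbours = nearestNeighbours(centers, 2);
        for (int i = 0; i < neighbours.length; i++) {
            System.out.println(i + " -> " + Arrays.toString(neighbours[i]));
        }
    }
}
```

The pairwise loop is quadratic in the number of centers, which is presumably why the benchmark sweeps `numVectors` up to 50000 to see where a graph-based search starts to pay off.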
  
    
            148 changes: 148 additions & 0 deletions
          
  server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the "Elastic License | ||
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||
| * Public License v 1"; you may not use this file except in compliance with, at | ||
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | ||
| * License v3.0 only", or the "Server Side Public License, v 1". | ||
| */ | ||
|  | ||
| package org.elasticsearch.index.codec.vectors.cluster; | ||
|  | ||
| import org.apache.lucene.index.VectorSimilarityFunction; | ||
| import org.apache.lucene.search.KnnCollector; | ||
| import org.apache.lucene.search.ScoreDoc; | ||
| import org.apache.lucene.util.Bits; | ||
| import org.apache.lucene.util.VectorUtil; | ||
| import org.apache.lucene.util.hnsw.HnswGraphBuilder; | ||
| import org.apache.lucene.util.hnsw.HnswGraphSearcher; | ||
| import org.apache.lucene.util.hnsw.OnHeapHnswGraph; | ||
| import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; | ||
| import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; | ||
| import org.elasticsearch.simdvec.ESVectorUtil; | ||
|  | ||
| import java.io.IOException; | ||
|  | ||
| public record NeighborHood(int[] neighbors, float maxIntraDistance) { | ||
|  | ||
| static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); | ||
|  | ||
| public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException { | ||
| final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() { | ||
| int scoringOrdinal; | ||
|  | ||
| @Override | ||
| public float score(int node) { | ||
| return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]); | ||
| } | ||
|  | ||
| @Override | ||
| public int maxOrd() { | ||
| return centers.length; | ||
| } | ||
|  | ||
| @Override | ||
| public void setScoringOrdinal(int node) { | ||
| scoringOrdinal = node; | ||
| } | ||
| }; | ||
| final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() { | ||
| @Override | ||
| public UpdateableRandomVectorScorer scorer() { | ||
| return scorer; | ||
| } | ||
|  | ||
| @Override | ||
| public RandomVectorScorerSupplier copy() { | ||
| return this; | ||
| } | ||
| }; | ||
| final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length); | ||
| final NeighborHood[] neighborhoods = new NeighborHood[centers.length]; | ||
| final SingleBit singleBit = new SingleBit(centers.length); | ||
| for (int i = 0; i < centers.length; i++) { | ||
| scorer.setScoringOrdinal(i); | ||
| singleBit.indexSet = i; | ||
| final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE); | ||
| final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs; | ||
| if (scoreDocs.length == 0) { | ||
| // no neighbors, skip | ||
| neighborhoods[i] = NeighborHood.EMPTY; | ||
| continue; | ||
| } | ||
| final int[] neighbors = new int[scoreDocs.length]; | ||
| for (int j = 0; j < neighbors.length; j++) { | ||
| neighbors[j] = scoreDocs[j].doc; | ||
| assert neighbors[j] != i; | ||
| } | ||
| final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1; | ||
| neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity); | ||
| } | ||
| return neighborhoods; | ||
| } | ||
|  | ||
| private static class SingleBit implements Bits { | ||
|  | ||
| private final int length; | ||
| private int indexSet; | ||
|  | ||
| SingleBit(int length) { | ||
| this.length = length; | ||
| } | ||
|  | ||
| @Override | ||
| public boolean get(int index) { | ||
| return index != indexSet; | ||
| } | ||
|  | ||
| @Override | ||
| public int length() { | ||
| return length; | ||
| } | ||
| } | ||
|  | ||
| public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) { | ||
| int k = centers.length; | ||
| NeighborQueue[] neighborQueues = new NeighborQueue[k]; | ||
| for (int i = 0; i < k; i++) { | ||
| neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); | ||
| } | ||
| final float[] scores = new float[4]; | ||
| final int limit = k - 3; | ||
| for (int i = 0; i < k - 1; i++) { | ||
| float[] center = centers[i]; | ||
| int j = i + 1; | ||
| for (; j < limit; j += 4) { | ||
| ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores); | ||
| for (int h = 0; h < 4; h++) { | ||
| neighborQueues[j + h].insertWithOverflow(i, scores[h]); | ||
| neighborQueues[i].insertWithOverflow(j + h, scores[h]); | ||
| } | ||
| } | ||
| for (; j < k; j++) { | ||
| float dsq = VectorUtil.squareDistance(center, centers[j]); | ||
| neighborQueues[j].insertWithOverflow(i, dsq); | ||
| neighborQueues[i].insertWithOverflow(j, dsq); | ||
| } | ||
| } | ||
|  | ||
| NeighborHood[] neighborhoods = new NeighborHood[k]; | ||
| for (int i = 0; i < k; i++) { | ||
| NeighborQueue queue = neighborQueues[i]; | ||
| if (queue.size() == 0) { | ||
| // no neighbors, skip | ||
| neighborhoods[i] = NeighborHood.EMPTY; | ||
| continue; | ||
| } | ||
| // consume the queue into the neighbors array and get the maximum intra-cluster distance | ||
| int[] neighbors = new int[queue.size()]; | ||
| float maxIntraDistance = queue.topScore(); | ||
| int iter = 0; | ||
| while (queue.size() > 0) { | ||
| neighbors[neighbors.length - ++iter] = queue.pop(); | ||
| } | ||
| neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); | ||
| } | ||
| return neighborhoods; | ||
| } | ||
| } | ||
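One detail of the graph path that is easy to miss: it stores `(1f / score) - 1` in the `maxIntraDistance` slot, while the brute-force path stores a squared distance directly. The two agree if Lucene's `EUCLIDEAN` similarity scores a pair as `1 / (1 + squaredDistance)`, which is my understanding of its documented behaviour. The sketch below checks that round trip in plain Java. The class and method names are hypothetical and nothing here calls Lucene itself.

```java
class EuclideanScoreSketch {

    /** Squared Euclidean distance between two vectors of equal length. */
    static float squareDistance(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Assumed form of the EUCLIDEAN similarity score: higher means closer, range (0, 1]. */
    static float euclideanScore(float[] a, float[] b) {
        return 1f / (1f + squareDistance(a, b));
    }

    /** Inverts the score back to a squared distance, mirroring the expression in the graph path. */
    static float scoreToSquareDistance(float score) {
        return (1f / score) - 1f;
    }

    public static void main(String[] args) {
        float[] a = { 0f, 0f };
        float[] b = { 3f, 4f };
        float score = euclideanScore(a, b);
        System.out.println("squared distance   = " + squareDistance(a, b));
        System.out.println("similarity score   = " + score);
        System.out.println("recovered distance = " + scoreToSquareDistance(score));
    }
}
```

Because the score decreases as distance grows, the last entry of a best-first result list is the farthest neighbour, so inverting its score gives the largest squared distance in the neighbourhood, up to floating-point rounding.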
      
I think we can optimise the graph to work better for lower scale but this is good as a first threshold. That's for segments greater than 1M with 64 vectors per centroid.
Reducing the number of connections could make this threshold smaller.
Agree. I didn't spend too much time on it because it seems pretty fast for low values (a few seconds), so I wonder whether there is a need to optimize those cases.
I think just picking something "good enough" is alright. It provides a nice improvement and any optimizations we make won't be "format breaking" :)
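To make the threshold discussion concrete, here is a minimal sketch of how a centroid-count cutoff could select between the two implementations. The actual threshold value and the place where the switch happens are not shown in this excerpt, so the `graphThreshold` parameter, the `15_000` figure in `main`, and every name below are hypothetical. The only number taken from the thread is the "1M vectors at 64 vectors per centroid" scale, which works out to 15,625 centroids.

```java
class NeighbourStrategySketch {

    enum Strategy { BRUTE_FORCE, GRAPH }

    /** Number of centroids needed for numVectors at the given vectors-per-centroid target, rounded up. */
    static int estimatedCentroids(long numVectors, int vectorsPerCentroid) {
        long centroids = (numVectors + vectorsPerCentroid - 1) / vectorsPerCentroid;
        return (int) Math.min(Integer.MAX_VALUE, centroids);
    }

    /** Chooses the graph path once the centroid count reaches the (hypothetical) threshold. */
    static Strategy choose(int numCentroids, int graphThreshold) {
        return numCentroids >= graphThreshold ? Strategy.GRAPH : Strategy.BRUTE_FORCE;
    }

    public static void main(String[] args) {
        // Illustrative value only; not the threshold used by the PR.
        int hypotheticalThreshold = 15_000;
        int centroids = estimatedCentroids(1_000_000L, 64);
        System.out.println(centroids + " centroids -> " + choose(centroids, hypotheticalThreshold));
    }
}
```

Since the choice only affects how neighbourhoods are computed at build time, a cutoff like this can be tuned later without changing what is written to disk, which matches the "won't be format breaking" remark above.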