|
10 | 10 | package org.elasticsearch.index.codec.vectors.cluster; |
11 | 11 |
|
12 | 12 | import org.apache.lucene.index.FloatVectorValues; |
| 13 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 14 | +import org.apache.lucene.search.KnnCollector; |
| 15 | +import org.apache.lucene.search.ScoreDoc; |
| 16 | +import org.apache.lucene.util.Bits; |
13 | 17 | import org.apache.lucene.util.FixedBitSet; |
14 | 18 | import org.apache.lucene.util.VectorUtil; |
| 19 | +import org.apache.lucene.util.hnsw.HnswGraphBuilder; |
| 20 | +import org.apache.lucene.util.hnsw.HnswGraphSearcher; |
15 | 21 | import org.apache.lucene.util.hnsw.IntToIntFunction; |
| 22 | +import org.apache.lucene.util.hnsw.OnHeapHnswGraph; |
| 23 | +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; |
| 24 | +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; |
16 | 25 | import org.elasticsearch.index.codec.vectors.SampleReader; |
17 | 26 | import org.elasticsearch.simdvec.ESVectorUtil; |
18 | 27 |
|
@@ -210,9 +219,92 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[] |
210 | 219 | return bestCentroidOffset; |
211 | 220 | } |
212 | 221 |
|
213 | | - private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) { |
| 222 | + private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException { |
| 223 | + assert centers.length > clustersPerNeighborhood; |
| 224 | + // experiments shows that below 20k, we better use brute force, otherwise hnsw gives us a nice speed up |
| 225 | + if (centers.length < 20_000) { |
| 226 | + return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood); |
| 227 | + } else { |
| 228 | + return computeNeighborhoodsGraph(centers, clustersPerNeighborhood); |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException { |
| 233 | + final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() { |
| 234 | + int scoringOrdinal; |
| 235 | + |
| 236 | + @Override |
| 237 | + public float score(int node) { |
| 238 | + return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]); |
| 239 | + } |
| 240 | + |
| 241 | + @Override |
| 242 | + public int maxOrd() { |
| 243 | + return centers.length; |
| 244 | + } |
| 245 | + |
| 246 | + @Override |
| 247 | + public void setScoringOrdinal(int node) { |
| 248 | + scoringOrdinal = node; |
| 249 | + } |
| 250 | + }; |
| 251 | + final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() { |
| 252 | + @Override |
| 253 | + public UpdateableRandomVectorScorer scorer() { |
| 254 | + return scorer; |
| 255 | + } |
| 256 | + |
| 257 | + @Override |
| 258 | + public RandomVectorScorerSupplier copy() { |
| 259 | + return this; |
| 260 | + } |
| 261 | + }; |
| 262 | + final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length); |
| 263 | + final NeighborHood[] neighborhoods = new NeighborHood[centers.length]; |
| 264 | + final SingleBit singleBit = new SingleBit(centers.length); |
| 265 | + for (int i = 0; i < centers.length; i++) { |
| 266 | + scorer.setScoringOrdinal(i); |
| 267 | + singleBit.indexSet = i; |
| 268 | + final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE); |
| 269 | + final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs; |
| 270 | + if (scoreDocs.length == 0) { |
| 271 | + // no neighbors, skip |
| 272 | + neighborhoods[i] = NeighborHood.EMPTY; |
| 273 | + continue; |
| 274 | + } |
| 275 | + final int[] neighbors = new int[scoreDocs.length]; |
| 276 | + for (int j = 0; j < neighbors.length; j++) { |
| 277 | + neighbors[j] = scoreDocs[j].doc; |
| 278 | + assert neighbors[j] != i; |
| 279 | + } |
| 280 | + final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1; |
| 281 | + neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity); |
| 282 | + } |
| 283 | + return neighborhoods; |
| 284 | + } |
| 285 | + |
| 286 | + private static class SingleBit implements Bits { |
| 287 | + |
| 288 | + private final int length; |
| 289 | + private int indexSet; |
| 290 | + |
| 291 | + SingleBit(int length) { |
| 292 | + this.length = length; |
| 293 | + } |
| 294 | + |
| 295 | + @Override |
| 296 | + public boolean get(int index) { |
| 297 | + return index != indexSet; |
| 298 | + } |
| 299 | + |
| 300 | + @Override |
| 301 | + public int length() { |
| 302 | + return length; |
| 303 | + } |
| 304 | + } |
| 305 | + |
| 306 | + static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) { |
214 | 307 | int k = centers.length; |
215 | | - assert k > clustersPerNeighborhood; |
216 | 308 | NeighborQueue[] neighborQueues = new NeighborQueue[k]; |
217 | 309 | for (int i = 0; i < k; i++) { |
218 | 310 | neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); |
|
0 commit comments