Skip to content

Commit 0dc4392

Browse files
authored
[DiskBBQ] Implement Bulk scoring when computing neighbours using a Hnsw graph (#135278)
1 parent ddcdf29 commit 0dc4392

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborHood.java

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.index.codec.vectors.cluster;
1111

12-
import org.apache.lucene.index.VectorSimilarityFunction;
1312
import org.apache.lucene.search.KnnCollector;
1413
import org.apache.lucene.search.TopDocs;
1514
import org.apache.lucene.search.knn.KnnSearchStrategy;
@@ -32,8 +31,8 @@ public record NeighborHood(int[] neighbors, float maxIntraDistance) {
3231

3332
public static NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
3433
assert centers.length > clustersPerNeighborhood;
35-
// experiments shows that below 15k, we better use brute force, otherwise hnsw gives us a nice speed up
36-
if (centers.length < 15_000) {
34+
// experiments show that below 10k centers brute force is the better choice; otherwise hnsw gives us a nice speed-up
35+
if (centers.length < 10_000) {
3736
return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
3837
} else {
3938
return computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
@@ -88,10 +87,33 @@ public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, i
8887
public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
8988
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
9089
int scoringOrdinal;
90+
private final float[] distances = new float[4];
9191

9292
@Override
9393
public float score(int node) {
94-
return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]);
94+
return VectorUtil.normalizeDistanceToUnitInterval(VectorUtil.squareDistance(centers[scoringOrdinal], centers[node]));
95+
}
96+
97+
@Override
98+
public void bulkScore(int[] nodes, float[] scores, int numNodes) {
99+
int i = 0;
100+
final int limit = numNodes - 3;
101+
for (; i < limit; i += 4) {
102+
ESVectorUtil.squareDistanceBulk(
103+
centers[scoringOrdinal],
104+
centers[nodes[i]],
105+
centers[nodes[i + 1]],
106+
centers[nodes[i + 2]],
107+
centers[nodes[i + 3]],
108+
distances
109+
);
110+
for (int j = 0; j < 4; j++) {
111+
scores[i + j] = VectorUtil.normalizeDistanceToUnitInterval(distances[j]);
112+
}
113+
}
114+
for (; i < numNodes; i++) {
115+
scores[i] = score(nodes[i]);
116+
}
95117
}
96118

97119
@Override

0 commit comments

Comments
 (0)