9
9
10
10
package org .elasticsearch .index .codec .vectors .cluster ;
11
11
12
- import org .apache .lucene .index .VectorSimilarityFunction ;
13
12
import org .apache .lucene .search .KnnCollector ;
14
13
import org .apache .lucene .search .TopDocs ;
15
14
import org .apache .lucene .search .knn .KnnSearchStrategy ;
@@ -32,8 +31,8 @@ public record NeighborHood(int[] neighbors, float maxIntraDistance) {
32
31
33
32
public static NeighborHood [] computeNeighborhoods (float [][] centers , int clustersPerNeighborhood ) throws IOException {
34
33
assert centers .length > clustersPerNeighborhood ;
35
- // experiments shows that below 15k , we better use brute force, otherwise hnsw gives us a nice speed up
36
- if (centers .length < 15_000 ) {
34
+ // experiments show that below 10k centers brute force is faster; otherwise hnsw gives us a nice speed up
35
+ if (centers .length < 10_000 ) {
37
36
return computeNeighborhoodsBruteForce (centers , clustersPerNeighborhood );
38
37
} else {
39
38
return computeNeighborhoodsGraph (centers , clustersPerNeighborhood );
@@ -88,10 +87,33 @@ public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, i
88
87
public static NeighborHood [] computeNeighborhoodsGraph (float [][] centers , int clustersPerNeighborhood ) throws IOException {
89
88
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer () {
90
89
int scoringOrdinal ;
90
+ private final float [] distances = new float [4 ];
91
91
92
92
// Scores the center at index `node` against the center at index `scoringOrdinal`.
// Diff view: the "-" line below is the removed implementation, the "+" line is its replacement.
// The replacement computes VectorUtil.squareDistance between the two center vectors and passes
// the result through VectorUtil.normalizeDistanceToUnitInterval.
// NOTE(review): presumably this returns the same value as the removed EUCLIDEAN.compare call — confirm.
@ Override
93
93
public float score (int node ) {
94
- return VectorSimilarityFunction .EUCLIDEAN .compare (centers [scoringOrdinal ], centers [node ]);
94
+ return VectorUtil .normalizeDistanceToUnitInterval (VectorUtil .squareDistance (centers [scoringOrdinal ], centers [node ]));
95
+ }
96
+
97
// Bulk variant of score(int): for every index k in [0, numNodes) writes into scores[k] the score of
// centers[nodes[k]] against centers[scoringOrdinal].
// Nodes are handled four at a time through ESVectorUtil.squareDistanceBulk, which is handed the
// reusable 4-element `distances` buffer; each of the four values is then mapped through
// VectorUtil.normalizeDistanceToUnitInterval. Any remaining (fewer than four) nodes fall back to score(int).
// NOTE(review): assumes squareDistanceBulk fills distances[j] with the squared distance to the j-th
// candidate vector, in argument order — confirm against ESVectorUtil.
+ @ Override
98
+ public void bulkScore (int [] nodes , float [] scores , int numNodes ) {
99
+ int i = 0 ;
100
// the 4-wide loop runs only while nodes[i + 3] is still inside the first numNodes entries
+ final int limit = numNodes - 3 ;
101
+ for (; i < limit ; i += 4 ) {
102
+ ESVectorUtil .squareDistanceBulk (
103
+ centers [scoringOrdinal ],
104
+ centers [nodes [i ]],
105
+ centers [nodes [i + 1 ]],
106
+ centers [nodes [i + 2 ]],
107
+ centers [nodes [i + 3 ]],
108
+ distances
109
+ );
110
// apply the same normalization that score(int) applies, so both paths produce comparable values
+ for (int j = 0 ; j < 4 ; j ++) {
111
+ scores [i + j ] = VectorUtil .normalizeDistanceToUnitInterval (distances [j ]);
112
+ }
113
+ }
114
// tail: the last numNodes % 4 nodes are scored one at a time
+ for (; i < numNodes ; i ++) {
115
+ scores [i ] = score (nodes [i ]);
116
+ }
95
117
}
96
118
97
119
@ Override
0 commit comments