99
1010package org .elasticsearch .index .codec .vectors .cluster ;
1111
12- import org .apache .lucene .index .VectorSimilarityFunction ;
1312import org .apache .lucene .search .KnnCollector ;
1413import org .apache .lucene .search .TopDocs ;
1514import org .apache .lucene .search .knn .KnnSearchStrategy ;
@@ -32,8 +31,8 @@ public record NeighborHood(int[] neighbors, float maxIntraDistance) {
3231
3332 public static NeighborHood [] computeNeighborhoods (float [][] centers , int clustersPerNeighborhood ) throws IOException {
3433 assert centers .length > clustersPerNeighborhood ;
35- // experiments shows that below 15k , we better use brute force, otherwise hnsw gives us a nice speed up
36- if (centers .length < 15_000 ) {
34+ // experiments shows that below 10k , we better use brute force, otherwise hnsw gives us a nice speed up
35+ if (centers .length < 10_000 ) {
3736 return computeNeighborhoodsBruteForce (centers , clustersPerNeighborhood );
3837 } else {
3938 return computeNeighborhoodsGraph (centers , clustersPerNeighborhood );
@@ -88,10 +87,33 @@ public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, i
8887 public static NeighborHood [] computeNeighborhoodsGraph (float [][] centers , int clustersPerNeighborhood ) throws IOException {
8988 final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer () {
9089 int scoringOrdinal ;
90+ private final float [] distances = new float [4 ];
9191
9292 @ Override
9393 public float score (int node ) {
94- return VectorSimilarityFunction .EUCLIDEAN .compare (centers [scoringOrdinal ], centers [node ]);
94+ return VectorUtil .normalizeDistanceToUnitInterval (VectorUtil .squareDistance (centers [scoringOrdinal ], centers [node ]));
95+ }
96+
97+ @ Override
98+ public void bulkScore (int [] nodes , float [] scores , int numNodes ) {
99+ int i = 0 ;
100+ final int limit = numNodes - 3 ;
101+ for (; i < limit ; i += 4 ) {
102+ ESVectorUtil .squareDistanceBulk (
103+ centers [scoringOrdinal ],
104+ centers [nodes [i ]],
105+ centers [nodes [i + 1 ]],
106+ centers [nodes [i + 2 ]],
107+ centers [nodes [i + 3 ]],
108+ distances
109+ );
110+ for (int j = 0 ; j < 4 ; j ++) {
111+ scores [i + j ] = VectorUtil .normalizeDistanceToUnitInterval (distances [j ]);
112+ }
113+ }
114+ for (; i < numNodes ; i ++) {
115+ scores [i ] = score (nodes [i ]);
116+ }
95117 }
96118
97119 @ Override
0 commit comments