|
10 | 10 | package org.elasticsearch.index.codec.vectors.cluster; |
11 | 11 |
|
12 | 12 | import org.apache.lucene.index.FloatVectorValues; |
13 | | -import org.apache.lucene.index.VectorSimilarityFunction; |
14 | | -import org.apache.lucene.search.KnnCollector; |
15 | | -import org.apache.lucene.search.ScoreDoc; |
16 | | -import org.apache.lucene.util.Bits; |
17 | 13 | import org.apache.lucene.util.FixedBitSet; |
18 | 14 | import org.apache.lucene.util.VectorUtil; |
19 | | -import org.apache.lucene.util.hnsw.HnswGraphBuilder; |
20 | | -import org.apache.lucene.util.hnsw.HnswGraphSearcher; |
21 | 15 | import org.apache.lucene.util.hnsw.IntToIntFunction; |
22 | | -import org.apache.lucene.util.hnsw.OnHeapHnswGraph; |
23 | | -import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; |
24 | | -import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; |
25 | 16 | import org.elasticsearch.index.codec.vectors.SampleReader; |
26 | 17 | import org.elasticsearch.simdvec.ESVectorUtil; |
27 | 18 |
|
@@ -148,40 +139,40 @@ private static int getBestCentroidFromNeighbours( |
148 | 139 | NeighborHood neighborhood, |
149 | 140 | float[] distances |
150 | 141 | ) { |
151 | | - final int limit = neighborhood.neighbors.length - 3; |
| 142 | + final int limit = neighborhood.neighbors().length - 3; |
152 | 143 | int bestCentroidOffset = centroidIdx; |
153 | 144 | assert centroidIdx >= 0 && centroidIdx < centroids.length; |
154 | 145 | float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); |
155 | 146 | int i = 0; |
156 | 147 | for (; i < limit; i += 4) { |
157 | | - if (minDsq < neighborhood.maxIntraDistance) { |
| 148 | + if (minDsq < neighborhood.maxIntraDistance()) { |
158 | 149 | // if the distance found is smaller than the maximum intra-cluster distance |
159 | 150 | // we don't consider it for further re-assignment |
160 | 151 | return bestCentroidOffset; |
161 | 152 | } |
162 | 153 | ESVectorUtil.squareDistanceBulk( |
163 | 154 | vector, |
164 | | - centroids[neighborhood.neighbors[i]], |
165 | | - centroids[neighborhood.neighbors[i + 1]], |
166 | | - centroids[neighborhood.neighbors[i + 2]], |
167 | | - centroids[neighborhood.neighbors[i + 3]], |
| 155 | + centroids[neighborhood.neighbors()[i]], |
| 156 | + centroids[neighborhood.neighbors()[i + 1]], |
| 157 | + centroids[neighborhood.neighbors()[i + 2]], |
| 158 | + centroids[neighborhood.neighbors()[i + 3]], |
168 | 159 | distances |
169 | 160 | ); |
170 | 161 | for (int j = 0; j < distances.length; j++) { |
171 | 162 | float dsq = distances[j]; |
172 | 163 | if (dsq < minDsq) { |
173 | 164 | minDsq = dsq; |
174 | | - bestCentroidOffset = neighborhood.neighbors[i + j]; |
| 165 | + bestCentroidOffset = neighborhood.neighbors()[i + j]; |
175 | 166 | } |
176 | 167 | } |
177 | 168 | } |
178 | | - for (; i < neighborhood.neighbors.length; i++) { |
179 | | - if (minDsq < neighborhood.maxIntraDistance) { |
| 169 | + for (; i < neighborhood.neighbors().length; i++) { |
| 170 | + if (minDsq < neighborhood.maxIntraDistance()) { |
180 | 171 | // if the distance found is smaller than the maximum intra-cluster distance |
181 | 172 | // we don't consider it for further re-assignment |
182 | 173 | return bestCentroidOffset; |
183 | 174 | } |
184 | | - int offset = neighborhood.neighbors[i]; |
| 175 | + int offset = neighborhood.neighbors()[i]; |
185 | 176 | // float score = neighborhood.scores[i]; |
186 | 177 | assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset; |
187 | 178 | // compute the distance to the centroid |
@@ -223,131 +214,12 @@ private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNe |
223 | 214 | assert centers.length > clustersPerNeighborhood; |
224 | 215 | // experiments shows that below 15k, we better use brute force, otherwise hnsw gives us a nice speed up |
225 | 216 | if (centers.length < 15_000) { |
226 | | - return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood); |
| 217 | + return NeighborHood.computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood); |
227 | 218 | } else { |
228 | | - return computeNeighborhoodsGraph(centers, clustersPerNeighborhood); |
| 219 | + return NeighborHood.computeNeighborhoodsGraph(centers, clustersPerNeighborhood); |
229 | 220 | } |
230 | 221 | } |
231 | 222 |
|
232 | | - static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException { |
233 | | - final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() { |
234 | | - int scoringOrdinal; |
235 | | - |
236 | | - @Override |
237 | | - public float score(int node) { |
238 | | - return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]); |
239 | | - } |
240 | | - |
241 | | - @Override |
242 | | - public int maxOrd() { |
243 | | - return centers.length; |
244 | | - } |
245 | | - |
246 | | - @Override |
247 | | - public void setScoringOrdinal(int node) { |
248 | | - scoringOrdinal = node; |
249 | | - } |
250 | | - }; |
251 | | - final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() { |
252 | | - @Override |
253 | | - public UpdateableRandomVectorScorer scorer() { |
254 | | - return scorer; |
255 | | - } |
256 | | - |
257 | | - @Override |
258 | | - public RandomVectorScorerSupplier copy() { |
259 | | - return this; |
260 | | - } |
261 | | - }; |
262 | | - final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length); |
263 | | - final NeighborHood[] neighborhoods = new NeighborHood[centers.length]; |
264 | | - final SingleBit singleBit = new SingleBit(centers.length); |
265 | | - for (int i = 0; i < centers.length; i++) { |
266 | | - scorer.setScoringOrdinal(i); |
267 | | - singleBit.indexSet = i; |
268 | | - final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE); |
269 | | - final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs; |
270 | | - if (scoreDocs.length == 0) { |
271 | | - // no neighbors, skip |
272 | | - neighborhoods[i] = NeighborHood.EMPTY; |
273 | | - continue; |
274 | | - } |
275 | | - final int[] neighbors = new int[scoreDocs.length]; |
276 | | - for (int j = 0; j < neighbors.length; j++) { |
277 | | - neighbors[j] = scoreDocs[j].doc; |
278 | | - assert neighbors[j] != i; |
279 | | - } |
280 | | - final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1; |
281 | | - neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity); |
282 | | - } |
283 | | - return neighborhoods; |
284 | | - } |
285 | | - |
286 | | - private static class SingleBit implements Bits { |
287 | | - |
288 | | - private final int length; |
289 | | - private int indexSet; |
290 | | - |
291 | | - SingleBit(int length) { |
292 | | - this.length = length; |
293 | | - } |
294 | | - |
295 | | - @Override |
296 | | - public boolean get(int index) { |
297 | | - return index != indexSet; |
298 | | - } |
299 | | - |
300 | | - @Override |
301 | | - public int length() { |
302 | | - return length; |
303 | | - } |
304 | | - } |
305 | | - |
306 | | - static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) { |
307 | | - int k = centers.length; |
308 | | - NeighborQueue[] neighborQueues = new NeighborQueue[k]; |
309 | | - for (int i = 0; i < k; i++) { |
310 | | - neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); |
311 | | - } |
312 | | - final float[] scores = new float[4]; |
313 | | - final int limit = k - 3; |
314 | | - for (int i = 0; i < k - 1; i++) { |
315 | | - float[] center = centers[i]; |
316 | | - int j = i + 1; |
317 | | - for (; j < limit; j += 4) { |
318 | | - ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores); |
319 | | - for (int h = 0; h < 4; h++) { |
320 | | - neighborQueues[j + h].insertWithOverflow(i, scores[h]); |
321 | | - neighborQueues[i].insertWithOverflow(j + h, scores[h]); |
322 | | - } |
323 | | - } |
324 | | - for (; j < k; j++) { |
325 | | - float dsq = VectorUtil.squareDistance(center, centers[j]); |
326 | | - neighborQueues[j].insertWithOverflow(i, dsq); |
327 | | - neighborQueues[i].insertWithOverflow(j, dsq); |
328 | | - } |
329 | | - } |
330 | | - |
331 | | - NeighborHood[] neighborhoods = new NeighborHood[k]; |
332 | | - for (int i = 0; i < k; i++) { |
333 | | - NeighborQueue queue = neighborQueues[i]; |
334 | | - if (queue.size() == 0) { |
335 | | - // no neighbors, skip |
336 | | - neighborhoods[i] = NeighborHood.EMPTY; |
337 | | - continue; |
338 | | - } |
339 | | - // consume the queue into the neighbors array and get the maximum intra-cluster distance |
340 | | - int[] neighbors = new int[queue.size()]; |
341 | | - float maxIntraDistance = queue.topScore(); |
342 | | - int iter = 0; |
343 | | - while (queue.size() > 0) { |
344 | | - neighbors[neighbors.length - ++iter] = queue.pop(); |
345 | | - } |
346 | | - neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); |
347 | | - } |
348 | | - return neighborhoods; |
349 | | - } |
350 | | - |
351 | 223 | private void assignSpilled( |
352 | 224 | FloatVectorValues vectors, |
353 | 225 | KMeansIntermediate kmeansIntermediate, |
@@ -391,8 +263,8 @@ private void assignSpilled( |
391 | 263 | if (neighborhoods != null) { |
392 | 264 | assert neighborhoods[currAssignment] != null; |
393 | 265 | NeighborHood neighborhood = neighborhoods[currAssignment]; |
394 | | - centroidCount = neighborhood.neighbors.length; |
395 | | - centroidOrds = c -> neighborhood.neighbors[c]; |
| 266 | + centroidCount = neighborhood.neighbors().length; |
| 267 | + centroidOrds = c -> neighborhood.neighbors()[c]; |
396 | 268 | } else { |
397 | 269 | centroidCount = centroids.length - 1; |
398 | 270 | centroidOrds = c -> c < currAssignment ? c : c + 1; // skip the current centroid |
@@ -436,10 +308,6 @@ private void assignSpilled( |
436 | 308 | } |
437 | 309 | } |
438 | 310 |
|
439 | | - record NeighborHood(int[] neighbors, float maxIntraDistance) { |
440 | | - static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); |
441 | | - } |
442 | | - |
443 | 311 | /** |
444 | 312 | * cluster using a lloyd k-means algorithm that is not neighbor aware |
445 | 313 | * |
|
0 commit comments