|
34 | 34 | import java.util.Arrays; |
35 | 35 | import java.util.List; |
36 | 36 | import java.util.Objects; |
| 37 | +import java.util.concurrent.Callable; |
37 | 38 | import java.util.concurrent.ForkJoinPool; |
38 | 39 | import java.util.concurrent.ThreadLocalRandom; |
39 | 40 | import java.util.concurrent.atomic.AtomicReference; |
@@ -183,12 +184,13 @@ public ProductQuantization refine(RandomAccessVectorValues ravv, |
183 | 184 | } |
184 | 185 | var vectors = vectorsMutable; // "effectively final" to make the closure happy |
185 | 186 |
|
186 | | - var refinedCodebooks = simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> { |
| 187 | + Callable<VectorFloat<?>[]> callable = () -> IntStream.range(0, M).parallel().mapToObj(m -> { |
187 | 188 | VectorFloat<?>[] subvectors = extractSubvectors(vectors, m, subvectorSizesAndOffsets); |
188 | 189 | var clusterer = new KMeansPlusPlusClusterer(subvectors, codebooks[m], anisotropicThreshold); |
189 | 190 | return clusterer.cluster(anisotropicThreshold == UNWEIGHTED ? lloydsRounds : 0, |
190 | 191 | anisotropicThreshold == UNWEIGHTED ? 0 : lloydsRounds); |
191 | | - }).toArray(VectorFloat<?>[]::new)).join(); |
| 192 | + }).toArray(VectorFloat<?>[]::new); |
| 193 | + var refinedCodebooks = simdExecutor.submit(callable).join(); |
192 | 194 |
|
193 | 195 | return new ProductQuantization(refinedCodebooks, clusterCount, subvectorSizesAndOffsets, globalCentroid, anisotropicThreshold); |
194 | 196 | } |
@@ -459,11 +461,12 @@ public int getClusterCount() { |
459 | 461 |
|
460 | 462 | static VectorFloat<?>[] createCodebooks(List<VectorFloat<?>> vectors, int[][] subvectorSizeAndOffset, int clusters, float anisotropicThreshold, ForkJoinPool simdExecutor) { |
461 | 463 | int M = subvectorSizeAndOffset.length; |
462 | | - return simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> { |
| 464 | + Callable<VectorFloat<?>[]> callable = () -> IntStream.range(0, M).parallel().mapToObj(m -> { |
463 | 465 | VectorFloat<?>[] subvectors = extractSubvectors(vectors, m, subvectorSizeAndOffset); |
464 | 466 | var clusterer = new KMeansPlusPlusClusterer(subvectors, clusters, anisotropicThreshold); |
465 | 467 | return clusterer.cluster(K_MEANS_ITERATIONS, anisotropicThreshold == UNWEIGHTED ? 0 : K_MEANS_ITERATIONS); |
466 | | - }).toArray(VectorFloat<?>[]::new)).join(); |
| 468 | + }).toArray(VectorFloat<?>[]::new); |
| 469 | + return simdExecutor.submit(callable).join(); |
467 | 470 | } |
468 | 471 |
|
469 | 472 | /** |
|
0 commit comments