|
36 | 36 | import java.util.concurrent.ForkJoinPool; |
37 | 37 | import java.util.concurrent.ThreadLocalRandom; |
38 | 38 | import java.util.concurrent.atomic.AtomicReference; |
| 39 | +import java.util.logging.Logger; |
39 | 40 | import java.util.stream.Collectors; |
40 | 41 | import java.util.stream.IntStream; |
41 | 42 |
|
|
53 | 54 | public class ProductQuantization implements VectorCompressor<ByteSequence<?>>, Accountable { |
54 | 55 | private static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination |
55 | 56 |
|
| 57 | + protected static final Logger LOG = Logger.getLogger(ProductQuantization.class.getName()); |
| 58 | + |
56 | 59 | private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); |
57 | 60 | static final int DEFAULT_CLUSTERS = 256; // number of clusters per subspace = one byte's worth |
58 | 61 | static final int K_MEANS_ITERATIONS = 6; |
@@ -114,6 +117,8 @@ public static ProductQuantization compute(RandomAccessVectorValues ravv, |
114 | 117 | ForkJoinPool simdExecutor, |
115 | 118 | ForkJoinPool parallelExecutor) |
116 | 119 | { |
| 120 | + checkClusterCount(clusterCount); |
| 121 | + |
117 | 122 | var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M); |
118 | 123 | var vectors = extractTrainingVectors(ravv, parallelExecutor); |
119 | 124 |
|
@@ -190,6 +195,8 @@ public ProductQuantization refine(RandomAccessVectorValues ravv, |
190 | 195 | } |
191 | 196 |
|
192 | 197 | ProductQuantization(VectorFloat<?>[] codebooks, int clusterCount, int[][] subvectorSizesAndOffsets, VectorFloat<?> globalCentroid, float anisotropicThreshold) { |
| 198 | + checkClusterCount(clusterCount); |
| 199 | + |
193 | 200 | this.codebooks = codebooks; |
194 | 201 | this.globalCentroid = globalCentroid; |
195 | 202 | this.M = codebooks.length; |
@@ -718,4 +725,13 @@ public String toString() { |
718 | 725 | anisotropicThreshold, |
719 | 726 | KMeansPlusPlusClusterer.computeParallelCostMultiplier(anisotropicThreshold, originalDimension)); |
720 | 727 | } |
| 728 | + |
| 729 | + private static void checkClusterCount(int clusterCount) { |
| 730 | + if (clusterCount > 256) { |
| 731 | + throw new IllegalArgumentException("Too many PQ clusters: " + clusterCount + " > 256"); |
| 732 | + } |
| 733 | + if (clusterCount < 256) { |
| 734 | + LOG.warning("Using less than 256 PQ clusters will not reduce the memory footprint."); |
| 735 | + } |
| 736 | + } |
721 | 737 | } |
0 commit comments