Skip to content

Commit aeeffa0

Browse files
Check values of clusterCount in PQ (#464)
Addresses issue #463 by checking the value of clusterCount in ProductQuantization. Since each PQ segment is encoded using a byte, using clusterCount = 256 is the optimal choice. With this PR, we now throw an exception if clusterCount > 256 and issue a warning if clusterCount < 256.
1 parent fbd23d5 commit aeeffa0

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.util.concurrent.ForkJoinPool;
3737
import java.util.concurrent.ThreadLocalRandom;
3838
import java.util.concurrent.atomic.AtomicReference;
39+
import java.util.logging.Logger;
3940
import java.util.stream.Collectors;
4041
import java.util.stream.IntStream;
4142

@@ -53,6 +54,8 @@
5354
public class ProductQuantization implements VectorCompressor<ByteSequence<?>>, Accountable {
5455
private static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination
5556

57+
protected static final Logger LOG = Logger.getLogger(ProductQuantization.class.getName());
58+
5659
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
5760
static final int DEFAULT_CLUSTERS = 256; // number of clusters per subspace = one byte's worth
5861
static final int K_MEANS_ITERATIONS = 6;
@@ -114,6 +117,8 @@ public static ProductQuantization compute(RandomAccessVectorValues ravv,
114117
ForkJoinPool simdExecutor,
115118
ForkJoinPool parallelExecutor)
116119
{
120+
checkClusterCount(clusterCount);
121+
117122
var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
118123
var vectors = extractTrainingVectors(ravv, parallelExecutor);
119124

@@ -190,6 +195,8 @@ public ProductQuantization refine(RandomAccessVectorValues ravv,
190195
}
191196

192197
ProductQuantization(VectorFloat<?>[] codebooks, int clusterCount, int[][] subvectorSizesAndOffsets, VectorFloat<?> globalCentroid, float anisotropicThreshold) {
198+
checkClusterCount(clusterCount);
199+
193200
this.codebooks = codebooks;
194201
this.globalCentroid = globalCentroid;
195202
this.M = codebooks.length;
@@ -718,4 +725,13 @@ public String toString() {
718725
anisotropicThreshold,
719726
KMeansPlusPlusClusterer.computeParallelCostMultiplier(anisotropicThreshold, originalDimension));
720727
}
728+
729+
private static void checkClusterCount(int clusterCount) {
730+
if (clusterCount > 256) {
731+
throw new IllegalArgumentException("Too many PQ clusters: " + clusterCount + " > 256");
732+
}
733+
if (clusterCount < 256) {
734+
LOG.warning("Using less than 256 PQ clusters will not reduce the memory footprint.");
735+
}
736+
}
721737
}

0 commit comments

Comments
 (0)