diff --git a/docs/changelog/126778.yaml b/docs/changelog/126778.yaml new file mode 100644 index 0000000000000..c695e24ba3c84 --- /dev/null +++ b/docs/changelog/126778.yaml @@ -0,0 +1,5 @@ +pr: 126778 +summary: Fix bbq quantization algorithm but for differently distributed components +area: Vector Search +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java index d5ed38cb5a0e1..31b254b5de560 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java @@ -75,8 +75,8 @@ public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destina assert bits[i] > 0 && bits[i] <= 8; int points = (1 << bits[i]); // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds - intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max); - intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max); + intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][0] * vecStd + vecMean, min, max); + intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][1] * vecStd + vecMean, min, max); optimizeIntervals(intervalScratch, vector, norm2, points); float nSteps = ((1 << bits[i]) - 1); float a = intervalScratch[0]; @@ -128,8 +128,8 @@ public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byt vecVar /= vector.length; double vecStd = Math.sqrt(vecVar); // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds - intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max); - intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max); + intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max); + intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max); optimizeIntervals(intervalScratch, vector, norm2, points); float nSteps = ((1 << bits) - 1); // Now we have the optimized intervals, quantize the vector diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java index e3e2d6caafe0e..73405ecc6d4fb 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java @@ -19,6 +19,62 @@ public class OptimizedScalarQuantizerTests extends ESTestCase { static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) { + float[] dequantized = new float[quantized.length]; + float a = interval[0]; + float b = interval[1]; + int nSteps = (1 << bits) - 1; + double step = (b - a) / nSteps; + for (int h = 0; h < quantized.length; h++) { + double xi = (double) (quantized[h] & 0xFF) * step + a; + dequantized[h] = (float) (xi + centroid[h]); + } + return dequantized; + } + + public void testQuantizationQuality() { + int dims = 16; + int numVectors = 32; + float[][] vectors = new float[numVectors][]; + float[] centroid = new float[dims]; + for (int i = 0; i < numVectors; ++i) { + vectors[i] = new float[dims]; + for (int j = 0; j < dims; ++j) { + vectors[i][j] = randomFloat(); + centroid[j] += vectors[i][j]; + } + } + for (int j = 0; j < dims; ++j) { + centroid[j] /= numVectors; + } + // similarity doesn't matter for this test + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT); + float[] scratch = new float[dims]; + for (byte bit : ALL_BITS) { + float eps = (1f / (float) (1 << (bit))); + byte[] destination = new byte[dims]; + for (int i = 0; i < numVectors; ++i) { + System.arraycopy(vectors[i], 0, scratch, 0, dims); + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + + float[] dequantized = deQuantize( + destination, + bit, + new float[] { result.lowerInterval(), result.upperInterval() }, + centroid + ); + float mae = 0; + for (int k = 0; k < dims; ++k) { + mae += Math.abs(dequantized[k] - vectors[i][k]); + } + mae /= dims; + assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps); + } + } + } + public void testAbusiveEdgeCases() { // large zero array for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {