diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java index 356353605..3cfe950c1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java @@ -157,4 +157,15 @@ public int hashCode() { public String toString() { return "BinaryQuantization"; } + + @Override + public double reconstructionError(VectorFloat vector) { + double sum = 0; + for (int i = 0; i < vector.length(); i++) { + boolean bit = vector.get(i) > 0; + double diff = (bit ? 1.0f : 0.0f) - vector.get(i); + sum += diff * diff; + } + return sum / vector.length(); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java index 95c048df8..093f5ad3c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java @@ -30,10 +30,15 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.concurrent.ForkJoinPool; +import java.util.stream.Collectors; import java.util.stream.IntStream; +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; +import static io.github.jbellis.jvector.vector.VectorUtil.sub; + /** * Non-uniform Vector Quantization for float vectors. @@ -355,6 +360,37 @@ public String toString() { return String.format("NVQuantization(sub-vectors=%d)", subvectorSizesAndOffsets.length); } + @Override + public double reconstructionError(VectorFloat vector) { + final var encodedVector = QuantizedVector.createEmpty(subvectorSizesAndOffsets, bitsPerDimension); + + final var tempVector = VectorUtil.sub(vector, globalMean); + QuantizedVector.quantizeTo(getSubVectors(tempVector), bitsPerDimension, learn, encodedVector); + final var vectorSubVectors = this.getSubVectors(tempVector); + + float dist = 0; + switch (this.bitsPerDimension) { + case EIGHT: + for (VectorFloat querySubVector : vectorSubVectors) { + VectorUtil.nvqShuffleQueryInPlace8bit(querySubVector); + } + + for (int i = 0; i < vectorSubVectors.length; i++) { + var svDB = encodedVector.subVectors[i]; + dist += VectorUtil.nvqSquareL2Distance8bit( + vectorSubVectors[i], + svDB.bytes, svDB.growthRate, svDB.midpoint, + svDB.minValue, svDB.maxValue + ); + } + break; + default: + throw new IllegalArgumentException("Unsupported bits per dimension " + this.bitsPerDimension); + } + + return dist / vector.length(); + } + /** * A NuVeQ vector. */ diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index 0e98bbf32..904f0dcd9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -40,7 +40,9 @@ import java.util.function.Supplier; import java.util.logging.Logger; import java.util.stream.Collectors; +import java.util.stream.DoubleStream; import java.util.stream.IntStream; +import java.util.stream.Stream; import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; import static io.github.jbellis.jvector.util.MathUtil.square; @@ -769,4 +771,28 @@ private static void checkClusterCount(int clusterCount) { public int getOriginalDimension() { return originalDimension; } + + @Override + public double reconstructionError(VectorFloat vector) { + var code = vectorTypeSupport.createByteSequence(M); + + if (globalCentroid != null) { + vector = sub(vector, globalCentroid); + } + + if (anisotropicThreshold > UNWEIGHTED) + encodeAnisotropic(vector, code); + else + encodeUnweighted(vector, code); + + float sum = 0; + for (int m = 0; m < M; m++) { + int centroidIndex = Byte.toUnsignedInt(code.get(m)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + sum += VectorUtil.squareL2Distance(codebooks[m], centroidIndex * centroidLength, vector, centroidOffset, centroidLength); + } + + return sum / vector.length(); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java index 09eb1e035..3b13a0282 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java @@ -26,6 +26,10 @@ import java.io.IOException; import java.util.List; import java.util.concurrent.ForkJoinPool; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.Stream; /** * Interface for vector compression. T is the encoded (compressed) vector type; @@ -74,4 +78,38 @@ default void write(DataOutput out) throws IOException { /** the size of a compressed vector */ int compressedVectorSize(); + + /** + * Compute the mean squared error (MSE) for the given vector. + *

+ * MSE = (sum of squared errors over all dimensions) / (number of dimensions) + * @param vector the vector to compute the reconstruction error for + * @return the reconstruction error for the given vector + */ + double reconstructionError(VectorFloat vector); + + /** + * Compute the mean squared error (MSE) for each vector in the stream. + *

+ * MSE = (sum of squared errors over all dimensions) / (number of dimensions) + * @param ravv the vectors to compute the reconstruction error for + * @return the reconstruction error for each vector + */ + default double[] reconstructionErrors(RandomAccessVectorValues ravv) { + return IntStream.range(0, ravv.size()).mapToDouble(i -> reconstructionError(ravv.getVector(i))).toArray(); + } + + /** + * Compute the mean squared error (MSE) for each vector in the stream in parallel. + *

+ * MSE = (sum of squared errors over all dimensions) / (number of dimensions) + * @param ravv the vectors to compute the reconstruction error for + * @param simdExecutor the ForkJoinPool to use for SIMD operations + * @return the reconstruction error for each vector + */ + default double[] reconstructionErrors(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) { + return simdExecutor.submit(() -> + IntStream.range(0, ravv.size()).mapToDouble(i -> reconstructionError(ravv.getVector(i))).toArray() + ).join(); + } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestReconstructionError.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestReconstructionError.java new file mode 100644 index 000000000..49af22f89 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestReconstructionError.java @@ -0,0 +1,114 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.quantization; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.OptionalDouble; + +import static io.github.jbellis.jvector.TestUtil.createRandomVectors; +import static org.junit.Assert.assertEquals; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestReconstructionError extends RandomizedTest { + private Path testDirectory; + + @Before + public void setup() throws IOException { + testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); + } + + @After + public void tearDown() { + TestUtil.deleteQuietly(testDirectory); + } + + @Test + public void testReconstructionError_withProductQuantization() { + testReconstructionError_withProductQuantization(1_000, 1.15, 2.5); + testReconstructionError_withProductQuantization(10_000, 0.14, 0.29); + } + + public void testReconstructionError_withProductQuantization(int nVectors, double toleranceAvg, double toleranceSTD) { + int dimensions = 32; + var ravv = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions); + var ravvTest = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions); + + ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true); + + compareErrors(pq, ravv, ravvTest, toleranceAvg, toleranceSTD); + } + + @Test + public void testReconstructionError_withBinaryQuantization() { + testReconstructionError_withBinaryQuantization(1_000, 0.05, 0.25); + testReconstructionError_withBinaryQuantization(10_000, 0.008, 0.09); + } + + public void testReconstructionError_withBinaryQuantization(int nVectors, double toleranceAvg, double toleranceSTD) { + int dimensions = 32; + var ravv = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions); + var ravvTest = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions); + + BinaryQuantization bq = new BinaryQuantization(dimensions); + + compareErrors(bq, ravv, ravvTest, toleranceAvg, toleranceSTD); + } + + @Test + public void testReconstructionError_withNVQuantization() { + testReconstructionError_withBinaryQuantization(1_000, 4e-2, 0.25); + testReconstructionError_withBinaryQuantization(10_000, 1e-2, 0.08); + } + + public void testReconstructionError_withNVQuantization(int nVectors, double toleranceAvg, double toleranceSTD) { + int dimensions = 32; + var ravv = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions); + var ravvTest = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions); + + NVQuantization nvq = NVQuantization.compute(ravv, 2); + + compareErrors(nvq, ravv, ravvTest, toleranceAvg, toleranceSTD); + } + + void compareErrors(VectorCompressor compressor, RandomAccessVectorValues sample1, RandomAccessVectorValues sample2, double toleranceAvg, double toleranceSTD) { + double[] errors1 = compressor.reconstructionErrors(sample1); + double averageError1 = Arrays.stream(errors1).average().getAsDouble(); + double varError1 = Arrays.stream(errors1).map(x -> (x - averageError1) * (x - averageError1)).average().getAsDouble(); + + double[] errors2 = compressor.reconstructionErrors(sample2); + double averageError2 = Arrays.stream(errors2).average().getAsDouble(); + double varError2 = Arrays.stream(errors2).map(x -> (x - averageError2) * (x - averageError2)).average().getAsDouble(); + + // check relative error + assertEquals(1, averageError2 / averageError1, toleranceAvg); + assertEquals(1, varError2 / varError1, toleranceSTD); + } +}