BinaryQuantization.java
@@ -157,4 +157,15 @@ public int hashCode() {
public String toString() {
return "BinaryQuantization";
}

@Override
public double reconstructionError(VectorFloat<?> vector) {
double sum = 0;
for (int i = 0; i < vector.length(); i++) {
boolean bit = vector.get(i) > 0;
double diff = (bit ? 1.0f : 0.0f) - vector.get(i);
sum += diff * diff;
}
return sum / vector.length();
}
}
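For intuition about the quantity being measured: under this definition, a 2-dimensional vector (0.3, -0.2) encodes to bits (1, 0), is reconstructed as (1.0, 0.0), and yields MSE = ((1.0 - 0.3)^2 + (0.0 - (-0.2))^2) / 2 = (0.49 + 0.04) / 2 = 0.265.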
NVQuantization.java
@@ -30,10 +30,15 @@
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
import static io.github.jbellis.jvector.vector.VectorUtil.sub;


/**
* Non-uniform Vector Quantization for float vectors.
@@ -355,6 +360,37 @@ public String toString() {
return String.format("NVQuantization(sub-vectors=%d)", subvectorSizesAndOffsets.length);
}

@Override
public double reconstructionError(VectorFloat<?> vector) {
final var encodedVector = QuantizedVector.createEmpty(subvectorSizesAndOffsets, bitsPerDimension);

// center the vector on the global mean, then quantize it
final var tempVector = VectorUtil.sub(vector, globalMean);
QuantizedVector.quantizeTo(getSubVectors(tempVector), bitsPerDimension, learn, encodedVector);
final var vectorSubVectors = this.getSubVectors(tempVector);

float dist = 0;
switch (this.bitsPerDimension) {
case EIGHT:
// shuffle each original sub-vector in place into the lane order expected by the NVQ distance kernels
for (VectorFloat<?> querySubVector : vectorSubVectors) {
VectorUtil.nvqShuffleQueryInPlace8bit(querySubVector);
}

// accumulate the squared L2 distance between each original sub-vector and its quantized form
for (int i = 0; i < vectorSubVectors.length; i++) {
var svDB = encodedVector.subVectors[i];
dist += VectorUtil.nvqSquareL2Distance8bit(
vectorSubVectors[i],
svDB.bytes, svDB.growthRate, svDB.midpoint,
svDB.minValue, svDB.maxValue
);
}
break;
default:
throw new IllegalArgumentException("Unsupported bits per dimension " + this.bitsPerDimension);
}

return dist / vector.length();
}

/**
* A NuVeQ vector.
*/
ProductQuantization.java
@@ -40,7 +40,9 @@
import java.util.function.Supplier;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
import static io.github.jbellis.jvector.util.MathUtil.square;
@@ -769,4 +771,28 @@ private static void checkClusterCount(int clusterCount) {
public int getOriginalDimension() {
return originalDimension;
}

@Override
public double reconstructionError(VectorFloat<?> vector) {
var code = vectorTypeSupport.createByteSequence(M);

if (globalCentroid != null) {
vector = sub(vector, globalCentroid);
}

if (anisotropicThreshold > UNWEIGHTED)
encodeAnisotropic(vector, code);
else
encodeUnweighted(vector, code);

// for each subspace, accumulate the squared L2 distance between the (centered) sub-vector
// and the codebook centroid selected by the encoding
float sum = 0;
for (int m = 0; m < M; m++) {
int centroidIndex = Byte.toUnsignedInt(code.get(m));
int centroidLength = subvectorSizesAndOffsets[m][0];
int centroidOffset = subvectorSizesAndOffsets[m][1];
sum += VectorUtil.squareL2Distance(codebooks[m], centroidIndex * centroidLength, vector, centroidOffset, centroidLength);
}

return sum / vector.length();
}
}
VectorCompressor.java
@@ -26,6 +26,10 @@
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
* Interface for vector compression. T is the encoded (compressed) vector type;
@@ -74,4 +78,38 @@ default void write(DataOutput out) throws IOException {

/** the size of a compressed vector */
int compressedVectorSize();

/**
* Compute the mean squared error (MSE) for the given vector.
* <p>
* MSE = (sum of squared errors over all dimensions) / (number of dimensions)
* @param vector the vector to compute the reconstruction error for
* @return the reconstruction error for the given vector
*/
double reconstructionError(VectorFloat<?> vector);

/**
* Compute the mean squared error (MSE) for each vector in the given RandomAccessVectorValues.
* <p>
* MSE = (sum of squared errors over all dimensions) / (number of dimensions)
* @param ravv the vectors to compute the reconstruction error for
* @return the reconstruction error for each vector
*/
default double[] reconstructionErrors(RandomAccessVectorValues ravv) {
return IntStream.range(0, ravv.size()).mapToDouble(i -> reconstructionError(ravv.getVector(i))).toArray();
}

/**
* Compute the mean squared error (MSE) for each vector in the given RandomAccessVectorValues, in parallel.
* <p>
* MSE = (sum of squared errors over all dimensions) / (number of dimensions)
* @param ravv the vectors to compute the reconstruction error for
* @param simdExecutor the ForkJoinPool to use for SIMD operations
* @return the reconstruction error for each vector
*/
default double[] reconstructionErrors(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
return simdExecutor.submit(() ->
IntStream.range(0, ravv.size()).mapToDouble(i -> reconstructionError(ravv.getVector(i))).toArray()
).join();
}
}
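For context, a minimal usage sketch of the new API; the wrapper class, its method name, and the PQ parameters (8 subspaces of 256 centroids, globally centered) are illustrative assumptions rather than part of this change:

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;

import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.types.VectorFloat;

class ReconstructionErrorExample {
    // returns the average MSE of quantizing-then-reconstructing each vector
    static double averageReconstructionError(List<VectorFloat<?>> vectors, int dimensions) {
        var ravv = new ListRandomAccessVectorValues(vectors, dimensions);
        ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true);

        // MSE of a single vector against its quantized form
        double first = pq.reconstructionError(ravv.getVector(0));

        // MSE of every vector, computed sequentially or on a ForkJoinPool
        double[] errors = pq.reconstructionErrors(ravv);
        double[] parallelErrors = pq.reconstructionErrors(ravv, ForkJoinPool.commonPool());

        return Arrays.stream(errors).average().orElse(Double.NaN);
    }
}

The test below applies the same pattern: it computes per-vector errors for two samples drawn from the same distribution and compares their mean and variance.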
TestReconstructionError.java (new file)
@@ -0,0 +1,114 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.quantization;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

import static io.github.jbellis.jvector.TestUtil.createRandomVectors;
import static org.junit.Assert.assertEquals;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class TestReconstructionError extends RandomizedTest {
private Path testDirectory;

@Before
public void setup() throws IOException {
testDirectory = Files.createTempDirectory(this.getClass().getSimpleName());
}

@After
public void tearDown() {
TestUtil.deleteQuietly(testDirectory);
}

@Test
public void testReconstructionError_withProductQuantization() {
testReconstructionError_withProductQuantization(1_000, 1.15, 2.5);
testReconstructionError_withProductQuantization(10_000, 0.14, 0.29);
}

public void testReconstructionError_withProductQuantization(int nVectors, double toleranceAvg, double toleranceSTD) {
int dimensions = 32;
var ravv = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions);
var ravvTest = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions);

ProductQuantization pq = ProductQuantization.compute(ravv, 8, 256, true);

compareErrors(pq, ravv, ravvTest, toleranceAvg, toleranceSTD);
}

@Test
public void testReconstructionError_withBinaryQuantization() {
testReconstructionError_withBinaryQuantization(1_000, 0.05, 0.25);
testReconstructionError_withBinaryQuantization(10_000, 0.008, 0.09);
}

public void testReconstructionError_withBinaryQuantization(int nVectors, double toleranceAvg, double toleranceSTD) {
int dimensions = 32;
var ravv = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions);
var ravvTest = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions);

BinaryQuantization bq = new BinaryQuantization(dimensions);

compareErrors(bq, ravv, ravvTest, toleranceAvg, toleranceSTD);
}

@Test
public void testReconstructionError_withNVQuantization() {
testReconstructionError_withNVQuantization(1_000, 4e-2, 0.25);
testReconstructionError_withNVQuantization(10_000, 1e-2, 0.08);
}

public void testReconstructionError_withNVQuantization(int nVectors, double toleranceAvg, double toleranceSTD) {
int dimensions = 32;
var ravv = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions);
var ravvTest = new ListRandomAccessVectorValues(createRandomVectors(nVectors, dimensions), dimensions);

NVQuantization nvq = NVQuantization.compute(ravv, 2);

compareErrors(nvq, ravv, ravvTest, toleranceAvg, toleranceSTD);
}

void compareErrors(VectorCompressor<?> compressor, RandomAccessVectorValues sample1, RandomAccessVectorValues sample2, double toleranceAvg, double toleranceSTD) {
double[] errors1 = compressor.reconstructionErrors(sample1);
double averageError1 = Arrays.stream(errors1).average().getAsDouble();
double varError1 = Arrays.stream(errors1).map(x -> (x - averageError1) * (x - averageError1)).average().getAsDouble();

double[] errors2 = compressor.reconstructionErrors(sample2);
double averageError2 = Arrays.stream(errors2).average().getAsDouble();
double varError2 = Arrays.stream(errors2).map(x -> (x - averageError2) * (x - averageError2)).average().getAsDouble();

// check relative error
assertEquals(1, averageError2 / averageError1, toleranceAvg);
assertEquals(1, varError2 / varError1, toleranceSTD);
}
}