@@ -161,8 +161,7 @@ public float applyCorrections(
         float qcDist
     ) {
         float ax = lowerInterval;
-        // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
-        float lx = upperInterval - ax;
+        float lx = (upperInterval - ax) * FOUR_BIT_SCALE;
         float ay = queryLowerInterval;
         float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
         float y1 = queryComponentSum;
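Note on the fix above: an int4 code c in [0, 15] dequantizes to ax + c * lx only when lx is the step width of the quantization grid, i.e. the interval length scaled down by the number of steps. Without the multiply, the data-side interval is roughly 15x too large relative to the query side, which already applies the scale (see ly above). A minimal round-trip sketch, assuming FOUR_BIT_SCALE == 1f / ((1 << 4) - 1), matching the constant's name:

// Sketch only; FOUR_BIT_SCALE is assumed to be 1/15, not quoted from the scorer.
class Int4DequantizeSketch {
    static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);

    static float dequantize(int code, float lowerInterval, float upperInterval) {
        // Step width of the 16-level grid spanning [lowerInterval, upperInterval].
        float lx = (upperInterval - lowerInterval) * FOUR_BIT_SCALE;
        return lowerInterval + code * lx;
    }

    public static void main(String[] args) {
        // The largest code must land exactly on the upper bound.
        System.out.println(dequantize(15, -0.5f, 0.5f)); // prints 0.5
        // Without the scale, code 15 would overshoot: -0.5 + 15 * 1.0 = 14.5.
    }
}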
@@ -354,7 +354,7 @@ private void applyCorrectionsBulk(
             memorySegment,
             offset + 4 * BULK_SIZE + i * Float.BYTES,
             ByteOrder.LITTLE_ENDIAN
-        ).sub(ax);
+        ).sub(ax).mul(FOUR_BIT_SCALE);
         var targetComponentSums = ShortVector.fromMemorySegment(
             SHORT_SPECIES,
             memorySegment,
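The vectorized bulk path gets the same fix lane-wise: sub(ax) and mul(FOUR_BIT_SCALE) broadcast the scalar to every lane, so each of the BULK_SIZE interval lengths is scaled exactly as in the scalar applyCorrections. A small sketch of that broadcast behavior with the Panama Vector API (values illustrative; needs --add-modules jdk.incubator.vector):

import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;

// Sketch only: element-wise sub/mul with scalar broadcast, mirroring
// `(upperInterval - ax) * FOUR_BIT_SCALE` across a whole vector of intervals.
class LaneWiseScaleSketch {
    static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    public static void main(String[] args) {
        float ax = 0.25f;       // stand-in for a lower interval bound
        float scale = 1f / 15f; // stand-in for FOUR_BIT_SCALE
        float[] uppers = new float[SPECIES.length()];
        java.util.Arrays.fill(uppers, 1.0f);
        FloatVector lx = FloatVector.fromArray(SPECIES, uppers, 0).sub(ax).mul(scale);
        System.out.println(lx); // every lane holds (1.0 - 0.25) / 15
    }
}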
@@ -20,7 +20,9 @@
 import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
-import static org.hamcrest.Matchers.lessThan;
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.greaterThan;
 
 public class ES91Int4VectorScorerTests extends BaseVectorizationTests {
 
@@ -130,31 +132,59 @@ public void testInt4ScoreBulk() throws Exception {
         // only even dimensions are supported
         final int dimensions = random().nextInt(1, 1000) * 2;
         final int numVectors = random().nextInt(1, 10) * ES91Int4VectorsScorer.BULK_SIZE;
-        final byte[] vector = new byte[ES91Int4VectorsScorer.BULK_SIZE * dimensions];
-        final byte[] corrections = new byte[ES91Int4VectorsScorer.BULK_SIZE * 14];
+        final float[][] vectors = new float[numVectors][dimensions];
+        final int[] quantizedScratch = new int[dimensions];
+        final byte[] quantizeVector = new byte[dimensions];
+        final float[] centroid = new float[dimensions];
+        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+        for (int i = 0; i < dimensions; i++) {
+            centroid[i] = random().nextFloat();
+        }
+        if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+            VectorUtil.l2normalize(centroid);
+        }
+
+        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
         try (Directory dir = new MMapDirectory(createTempDir())) {
             try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+                OptimizedScalarQuantizer.QuantizationResult[] results =
+                    new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE];
                 for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
-                    for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE * dimensions; j++) {
-                        vector[j] = (byte) random().nextInt(16); // 4-bit quantization
+                    for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
+                        for (int k = 0; k < dimensions; k++) {
+                            vectors[i + j][k] = random().nextFloat();
+                        }
+                        if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+                            VectorUtil.l2normalize(vectors[i + j]);
+                        }
+                        results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 4, centroid);
+                        for (int k = 0; k < dimensions; k++) {
+                            quantizeVector[k] = (byte) quantizedScratch[k];
+                        }
+                        out.writeBytes(quantizeVector, 0, dimensions);
                     }
-                    out.writeBytes(vector, 0, vector.length);
-                    random().nextBytes(corrections);
-                    out.writeBytes(corrections, 0, corrections.length);
+                    writeCorrections(results, out);
                 }
             }
-            final byte[] query = new byte[dimensions];
+            final float[] query = new float[dimensions];
+            final byte[] quantizeQuery = new byte[dimensions];
             for (int j = 0; j < dimensions; j++) {
-                query[j] = (byte) random().nextInt(16); // 4-bit quantization
+                query[j] = random().nextFloat();
             }
-            OptimizedScalarQuantizer.QuantizationResult queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
-                random().nextFloat(),
-                random().nextFloat(),
-                random().nextFloat(),
-                Short.toUnsignedInt((short) random().nextInt())
+            if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+                VectorUtil.l2normalize(query);
+            }
+            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+                query.clone(),
+                quantizedScratch,
+                (byte) 4,
+                centroid
             );
-            float centroidDp = random().nextFloat();
-            VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+            for (int j = 0; j < dimensions; j++) {
+                quantizeQuery[j] = (byte) quantizedScratch[j];
+            }
+            float centroidDp = VectorUtil.dotProduct(centroid, centroid);
 
             try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
                 // Work on a slice that has just the right number of bytes to make the test fail with an
                 // index-out-of-bounds in case the implementation reads more than the allowed number of
@@ -166,7 +196,7 @@ public void testInt4ScoreBulk() throws Exception {
                 float[] scoresPanama = new float[ES91Int4VectorsScorer.BULK_SIZE];
                 for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
                     defaultScorer.scoreBulk(
-                        query,
+                        quantizeQuery,
                         queryCorrections.lowerInterval(),
                         queryCorrections.upperInterval(),
                         queryCorrections.quantizedComponentSum(),
@@ -176,7 +206,7 @@
                         scoresDefault
                     );
                     panamaScorer.scoreBulk(
-                        query,
+                        quantizeQuery,
                         queryCorrections.lowerInterval(),
                         queryCorrections.upperInterval(),
                         queryCorrections.quantizedComponentSum(),
@@ -186,29 +216,34 @@
                         scoresPanama
                     );
                     for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                        if (scoresDefault[j] == scoresPanama[j]) {
-                            continue;
-                        }
-                        if (scoresDefault[j] > (1000 * Byte.MAX_VALUE)) {
-                            float diff = Math.abs(scoresDefault[j] - scoresPanama[j]);
-                            assertThat(
-                                "defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j],
-                                diff / scoresDefault[j],
-                                lessThan(1e-5f)
-                            );
-                            assertThat(
-                                "defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j],
-                                diff / scoresPanama[j],
-                                lessThan(1e-5f)
-                            );
-                        } else {
-                            assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
-                        }
+                        assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
+                        float realSimilarity = similarityFunction.compare(vectors[i + j], query);
+                        float accuracy = realSimilarity > scoresDefault[j]
+                            ? scoresDefault[j] / realSimilarity
+                            : realSimilarity / scoresDefault[j];
+                        assertThat(accuracy, greaterThan(0.90f));
                     }
                     assertEquals(in.getFilePointer(), slice.getFilePointer());
                 }
                 assertEquals((long) (dimensions + 14) * numVectors, in.getFilePointer());
             }
         }
     }
 
+    private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            int targetComponentSum = correction.quantizedComponentSum();
+            assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+            out.writeShort((short) targetComponentSum);
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        }
+    }
 }
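The new writeCorrections helper lays the block out field by field rather than vector by vector: BULK_SIZE lower bounds, then BULK_SIZE upper bounds, then the component sums as unsigned shorts, then the additional corrections, i.e. 4 + 4 + 2 + 4 = 14 bytes per vector, which is exactly what the (dimensions + 14) * numVectors file-pointer assertion checks. Grouping by field keeps each correction type contiguous, so the Panama bulk path can load BULK_SIZE values of one kind from a single offset (note the offset + 4 * BULK_SIZE addressing for the upper intervals in the second hunk). A hedged sketch of the matching read side; the IndexInput calls are Lucene's standard DataInput API, but the decoder itself is illustrative, not the scorer's actual code:

// Illustrative decoder for the column-ordered corrections block written by
// writeCorrections; assumes the little-endian byte order Lucene uses by default.
private static void readCorrections(IndexInput in, int bulkSize, float[] lowers, float[] uppers, int[] sums, float[] extras)
    throws IOException {
    for (int i = 0; i < bulkSize; i++) {
        lowers[i] = Float.intBitsToFloat(in.readInt());  // all lower interval bounds first
    }
    for (int i = 0; i < bulkSize; i++) {
        uppers[i] = Float.intBitsToFloat(in.readInt());  // then all upper interval bounds
    }
    for (int i = 0; i < bulkSize; i++) {
        sums[i] = Short.toUnsignedInt(in.readShort());   // component sums, stored as unsigned shorts
    }
    for (int i = 0; i < bulkSize; i++) {
        extras[i] = Float.intBitsToFloat(in.readInt());  // finally the additional corrections
    }
}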