diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7ScorerBenchmark.java
new file mode 100644
index 0000000000000..6e97893354349
--- /dev/null
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7ScorerBenchmark.java
@@ -0,0 +1,170 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * JMH benchmark for {@link ES92Int7VectorsScorer}: scores random 7-bit quantized vectors, read back from a
+ * memory-mapped temporary file, against random 7-bit queries, either one vector or one bulk per call.
+ */
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+// first iteration is complete garbage, so make sure we really warmup
+@Warmup(iterations = 4, time = 1)
+// real iterations. not useful to spend tons of time here, better to fork more
+@Measurement(iterations = 5, time = 1)
+// engage some noise reduction
+@Fork(value = 1)
+public class Int7ScorerBenchmark {
+
+    static {
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    @Param({ "384", "782", "1024" })
+    int dims;
+
+    int numVectors = 20 * ES92Int7VectorsScorer.BULK_SIZE;
+    int numQueries = 5;
+
+    byte[] scratch;
+    byte[][] binaryVectors;
+    byte[][] binaryQueries;
+    float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
+
+    ES92Int7VectorsScorer scorer;
+    Directory dir;
+    IndexInput in;
+
+    OptimizedScalarQuantizer.QuantizationResult queryCorrections;
+    float centroidDp;
+
+    /**
+     * Writes {@code numVectors} random 7-bit vectors, each followed by 16 bytes of random correction data, to a
+     * temporary directory, then builds {@code numQueries} random 7-bit queries and opens the scorer on the file.
+     */
+    @Setup
+    public void setup() throws IOException {
+        binaryVectors = new byte[numVectors][dims];
+        dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
+        try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
+            for (byte[] binaryVector : binaryVectors) {
+                for (int i = 0; i < dims; i++) {
+                    // 7-bit quantization
+                    binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(128);
+                }
+                out.writeBytes(binaryVector, 0, binaryVector.length);
+                ThreadLocalRandom.current().nextBytes(binaryVector);
+                out.writeBytes(binaryVector, 0, 16); // corrections
+            }
+        }
+
+        queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
+            ThreadLocalRandom.current().nextFloat(),
+            ThreadLocalRandom.current().nextFloat(),
+            ThreadLocalRandom.current().nextFloat(),
+            Short.toUnsignedInt((short) ThreadLocalRandom.current().nextInt())
+        );
+        centroidDp = ThreadLocalRandom.current().nextFloat();
+
+        in = dir.openInput("vectors", IOContext.DEFAULT);
+        binaryQueries = new byte[numQueries][dims]; // one row per query; the benchmarks only index [0, numQueries)
+        for (byte[] binaryQuery : binaryQueries) {
+            for (int i = 0; i < dims; i++) {
+                // 7-bit quantization
+                binaryQuery[i] = (byte) ThreadLocalRandom.current().nextInt(128);
+            }
+        }
+
+        scratch = new byte[dims];
+        scorer = ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(in, dims);
+    }
+
+    /** Closes the directory and the index input opened by {@link #setup()}. */
+    @TearDown
+    public void teardown() throws IOException {
+        IOUtils.close(dir, in);
+    }
+
+    /** Scores every stored vector against each query, one vector per {@code score} call. */
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromMemorySegment(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0); // rewind so that each query scans the whole file
+            for (int i = 0; i < numVectors; i++) {
+                bh.consume(
+                    scorer.score(
+                        binaryQueries[j],
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        VectorSimilarityFunction.EUCLIDEAN,
+                        centroidDp
+                    )
+                );
+            }
+        }
+    }
+
+    /** Scores every stored vector against each query, {@link ES92Int7VectorsScorer#BULK_SIZE} vectors per {@code scoreBulk} call. */
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromMemorySegmentBulk(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0); // rewind so that each query scans the whole file
+            for (int i = 0; i < numVectors; i += ES92Int7VectorsScorer.BULK_SIZE) {
+                scorer.scoreBulk(
+                    binaryQueries[j],
+                    queryCorrections.lowerInterval(),
+                    queryCorrections.upperInterval(),
+                    queryCorrections.quantizedComponentSum(),
+                    queryCorrections.additionalCorrection(),
+                    VectorSimilarityFunction.EUCLIDEAN,
+                    centroidDp,
+                    scores
+                );
+                for (float score : scores) {
+                    bh.consume(score);
+                }
+            }
+        }
+    }
+}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java
new file 
mode 100644
index 0000000000000..c405e0ad33677
--- /dev/null
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/**
+ * Scorer for 7 bit quantized vectors stored in a {@link IndexInput}.
+ * Queries are expected to be quantized using 7 bits as well.
+ * */
+public class ES92Int7VectorsScorer {
+
+    public static final int BULK_SIZE = 16; // number of vectors scored by one scoreBulk call
+    protected static final float SEVEN_BIT_SCALE = 1f / ((1 << 7) - 1); // 1 / 127, the inverse of the largest 7-bit value
+
+    /** The wrapped {@link IndexInput}; every scoring call reads from its current position and advances it. */
+    protected final IndexInput in;
+    protected final int dimensions;
+
+    private final float[] lowerIntervals = new float[BULK_SIZE];
+    private final float[] upperIntervals = new float[BULK_SIZE];
+    private final int[] targetComponentSums = new int[BULK_SIZE];
+    private final float[] additionalCorrections = new float[BULK_SIZE];
+
+    /** Sole constructor, called by sub-classes. */
+    public ES92Int7VectorsScorer(IndexInput in, int dimensions) {
+        this.in = in;
+        this.dimensions = dimensions;
+    }
+
+    /**
+     * Computes the dot product between the provided quantized query and the next quantized vector,
+     * which is read as {@code dimensions} bytes from the wrapped {@link IndexInput}.
+     */
+    public long int7DotProduct(byte[] b) throws IOException {
+        int total = 0;
+        for (int i = 0; i < dimensions; i++) {
+            total += in.readByte() * b[i];
+        }
+        return total;
+    }
+
+    /**
+     * Computes the dot products between the provided quantized query and the quantized vectors
+     * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
+     * determined by {@code count} and the results are stored in the provided {@code scores} array.
+     */
+    public void int7DotProductBulk(byte[] b, int count, float[] scores) throws IOException {
+        for (int i = 0; i < count; i++) {
+            scores[i] = int7DotProduct(b);
+        }
+    }
+
+    /**
+     * Scores the next vector: reads it and its corrections (three floats, then one int) from the input and applies them.
+     */
+    public float score(
+        byte[] q,
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp
+    ) throws IOException {
+        float score = int7DotProduct(q);
+        in.readFloats(lowerIntervals, 0, 3); // reuses the bulk buffer: [0] lower interval, [1] upper interval, [2] additional correction
+        int addition = in.readInt(); // passed on as the target component sum
+        return applyCorrections(
+            queryLowerInterval,
+            queryUpperInterval,
+            queryComponentSum,
+            queryAdditionalCorrection,
+            similarityFunction,
+            centroidDp,
+            lowerIntervals[0],
+            lowerIntervals[1],
+            addition,
+            lowerIntervals[2],
+            score
+        );
+    }
+
+    /**
+     * Computes the scores between the provided quantized query and the quantized vectors that are
+     * read from the wrapped {@link IndexInput}.
+     *
+     * <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
+     * input is as follows: First the quantized vectors are read from the input, then all the lower
+     * intervals as floats, then all the upper intervals as floats, then all the target component sums
+     * as ints, and finally all the additional corrections as floats.
+     *
+     * <p>The results are stored in the provided scores array.
+     */
+    public void scoreBulk(
+        byte[] q,
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        int7DotProductBulk(q, BULK_SIZE, scores);
+        in.readFloats(lowerIntervals, 0, BULK_SIZE);
+        in.readFloats(upperIntervals, 0, BULK_SIZE);
+        in.readInts(targetComponentSums, 0, BULK_SIZE);
+        in.readFloats(additionalCorrections, 0, BULK_SIZE);
+        for (int i = 0; i < BULK_SIZE; i++) {
+            scores[i] = applyCorrections(
+                queryLowerInterval,
+                queryUpperInterval,
+                queryComponentSum,
+                queryAdditionalCorrection,
+                similarityFunction,
+                centroidDp,
+                lowerIntervals[i],
+                upperIntervals[i],
+                targetComponentSums[i],
+                additionalCorrections[i],
+                scores[i]
+            );
+        }
+    }
+
+    /**
+     * Computes the score by applying the necessary corrections to the provided quantized distance.
+     */
+    public float applyCorrections(
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float lowerInterval,
+        float upperInterval,
+        int targetComponentSum,
+        float additionalCorrection,
+        float qcDist
+    ) {
+        float ax = lowerInterval;
+        // Both interval widths (lx for the target, ly for the query) are rescaled by SEVEN_BIT_SCALE
+        float lx = (upperInterval - ax) * SEVEN_BIT_SCALE;
+        float ay = queryLowerInterval;
+        float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
+        float y1 = queryComponentSum;
+        float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+        // For euclidean, we need to invert the score and apply the additional correction, which is
+        // assumed to be the squared l2norm of the centroid centered vectors.
+        if (similarityFunction == EUCLIDEAN) {
+            score = queryAdditionalCorrection + additionalCorrection - 2 * score;
+            return Math.max(1 / (1f + score), 0);
+        } else {
+            // For cosine and max inner product, we need to apply the additional correction, which is
+            // assumed to be the non-centered dot-product between the vector and the centroid
+            score += queryAdditionalCorrection + additionalCorrection - centroidDp;
+            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                return VectorUtil.scaleMaxInnerProductScore(score);
+            }
+            return Math.max((1f + score) / 2f, 0);
+        }
+    }
+}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
index c3091dcb96882..5b14b39d37fb0 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
@@ -51,6 +51,10 @@ public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, i
         return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
     }
 
+    public static ES92Int7VectorsScorer getES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException {
+        return ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(input, dimension);
+    }
+
     public static long ipByteBinByte(byte[] q, byte[] d) {
         if (q.length != d.length * B_QUERY) {
             throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
index 5bdd7a724ceda..4c4cd98bdd781 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
+++ 
b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java @@ -12,8 +12,7 @@ import org.apache.lucene.store.IndexInput; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; - -import java.io.IOException; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; final class DefaultESVectorizationProvider extends ESVectorizationProvider { private final ESVectorUtilSupport vectorUtilSupport; @@ -28,12 +27,17 @@ public ESVectorUtilSupport getVectorUtilSupport() { } @Override - public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { + public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) { return new ES91OSQVectorsScorer(input, dimension); } @Override - public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException { + public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) { return new ES91Int4VectorsScorer(input, dimension); } + + @Override + public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) { + return new ES92Int7VectorsScorer(input, dimension); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 719284f48471c..d174c31401f02 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -12,6 +12,7 @@ import org.apache.lucene.store.IndexInput; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import 
java.io.IOException; import java.util.Objects; @@ -35,6 +36,9 @@ public static ESVectorizationProvider getInstance() { /** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */ public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException; + /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */ + public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException; + // visible for tests static ESVectorizationProvider lookup(boolean testMode) { return new DefaultESVectorizationProvider(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java new file mode 100644 index 0000000000000..6edf60fff1c83 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java @@ -0,0 +1,352 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */
+package org.elasticsearch.simdvec.internal;
+
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.Vector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorShape;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
+
+import static java.nio.ByteOrder.LITTLE_ENDIAN;
+import static jdk.incubator.vector.VectorOperators.ADD;
+import static jdk.incubator.vector.VectorOperators.B2I;
+import static jdk.incubator.vector.VectorOperators.B2S;
+import static jdk.incubator.vector.VectorOperators.S2I;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}, loading them straight from its {@link MemorySegment}. **/
+public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
+
+    private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
+    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+
+    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
+    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
+
+    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
+    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
+    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
+
+    private static final int VECTOR_BITSIZE;
+    private static final VectorSpecies<Float> FLOAT_SPECIES;
+    private static final VectorSpecies<Integer> INT_SPECIES;
+
+    static {
+        // default to platform supported bitsize
+        VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
+        FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
+        INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
+    }
+
+    private final MemorySegment memorySegment; // addressed with the file pointer of `in` as the byte offset
+
+    public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+        super(in, dimensions);
+        this.memorySegment = memorySegment;
+    }
+
+    @Override
+    public long int7DotProduct(byte[] q) throws IOException {
+        assert dimensions == q.length;
+        int i = 0;
+        int res = 0;
+        // only vectorize if we'll at least enter the loop a single time
+        if (dimensions >= 16) {
+            // compute vectorized dot product consistent with VPDPBUSD instruction
+            if (VECTOR_BITSIZE >= 512) {
+                i += BYTE_SPECIES_128.loopBound(dimensions);
+                res += dotProductBody512(q, i);
+            } else if (VECTOR_BITSIZE == 256) {
+                i += BYTE_SPECIES_64.loopBound(dimensions);
+                res += dotProductBody256(q, i);
+            } else {
+                // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
+                i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
+                res += dotProductBody128(q, i);
+            }
+            // scalar tail: the body has already advanced the input by i bytes
+            while (i < dimensions) {
+                res += in.readByte() * q[i++];
+            }
+            return res;
+        } else {
+            return super.int7DotProduct(q);
+        }
+    }
+
+    private int dotProductBody512(byte[] q, int limit) throws IOException {
+        IntVector acc = IntVector.zero(INT_SPECIES_512);
+        long offset = in.getFilePointer();
+        for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
+            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
+            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
+
+            // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
+            Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
+            Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
+            Vector<Short> prod16 = va16.mul(vb16);
+
+            // 32-bit add
+            Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
+            acc = acc.add(prod32);
+        }
+
+        in.seek(offset + limit); // advance the input stream
+        // reduce
+        return acc.reduceLanes(ADD);
+    }
+
+    private int dotProductBody256(byte[] q, int limit) throws IOException {
+        IntVector acc = IntVector.zero(INT_SPECIES_256);
+        long offset = in.getFilePointer();
+        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
+            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+            // 32-bit multiply and add into accumulator
+            Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
+            Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
+            acc = acc.add(va32.mul(vb32));
+        }
+        in.seek(offset + limit); // advance the input stream
+        // reduce
+        return acc.reduceLanes(ADD);
+    }
+
+    private int dotProductBody128(byte[] q, int limit) throws IOException {
+        IntVector acc = IntVector.zero(IntVector.SPECIES_128);
+        long offset = in.getFilePointer();
+        // 4 bytes at a time (re-loading half the vector each time!)
+        for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
+            // load 8 bytes
+            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+            // process first "half" only: 16-bit multiply
+            Vector<Short> va16 = va8.convert(B2S, 0);
+            Vector<Short> vb16 = vb8.convert(B2S, 0);
+            Vector<Short> prod16 = va16.mul(vb16);
+
+            // 32-bit add
+            acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
+        }
+        in.seek(offset + limit); // advance the input stream
+        // reduce
+        return acc.reduceLanes(ADD);
+    }
+
+    @Override
+    public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
+        assert dimensions == q.length;
+        // only vectorize if we'll at least enter the loop a single time
+        if (dimensions >= 16) {
+            // compute vectorized dot product consistent with VPDPBUSD instruction
+            if (VECTOR_BITSIZE >= 512) {
+                dotProductBody512Bulk(q, count, scores);
+            } else if (VECTOR_BITSIZE == 256) {
+                dotProductBody256Bulk(q, count, scores);
+            } else {
+                // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
+                dotProductBody128Bulk(q, count, scores);
+            }
+        } else {
+            super.int7DotProductBulk(q, count, scores); // scalar fallback; an unqualified call would recurse into this method forever
+        }
+    }
+
+    private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException {
+        int limit = BYTE_SPECIES_128.loopBound(dimensions);
+        for (int iter = 0; iter < count; iter++) {
+            IntVector acc = IntVector.zero(INT_SPECIES_512);
+            long offset = in.getFilePointer();
+            int i = 0;
+            for (; i < limit; i += BYTE_SPECIES_128.length()) {
+                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
+                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
+
+                // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
+                Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
+                Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
+                Vector<Short> prod16 = va16.mul(vb16);
+
+                // 32-bit add
+                Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
+                acc = acc.add(prod32);
+            }
+
+            in.seek(offset + limit); // advance the input stream
+            // reduce, then finish the remaining dimensions with a scalar tail
+            long res = acc.reduceLanes(ADD);
+            for (; i < dimensions; i++) {
+                res += in.readByte() * q[i];
+            }
+            scores[iter] = res;
+        }
+    }
+
+    private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException {
+        int limit = BYTE_SPECIES_64.loopBound(dimensions); // bound matches the 64-bit loads below, as in int7DotProduct
+        for (int iter = 0; iter < count; iter++) {
+            IntVector acc = IntVector.zero(INT_SPECIES_256);
+            long offset = in.getFilePointer();
+            int i = 0;
+            for (; i < limit; i += BYTE_SPECIES_64.length()) {
+                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+                // 32-bit multiply and add into accumulator
+                Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
+                Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
+                acc = acc.add(va32.mul(vb32));
+            }
+            in.seek(offset + limit); // advance the input stream
+            // reduce, then finish the remaining dimensions with a scalar tail
+            long res = acc.reduceLanes(ADD);
+            for (; i < dimensions; i++) {
+                res += in.readByte() * q[i];
+            }
+            scores[iter] = res;
+        }
+    }
+
+    private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException {
+        int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
+        for (int iter = 0; iter < count; iter++) {
+            IntVector acc = IntVector.zero(IntVector.SPECIES_128);
+            long offset = in.getFilePointer();
+            // 4 bytes at a time (re-loading half the vector each time!)
+            int i = 0;
+            for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
+                // load 8 bytes
+                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+                // process first "half" only: 16-bit multiply
+                Vector<Short> va16 = va8.convert(B2S, 0);
+                Vector<Short> vb16 = vb8.convert(B2S, 0);
+                Vector<Short> prod16 = va16.mul(vb16);
+
+                // 32-bit add
+                acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
+            }
+            in.seek(offset + limit); // advance the input stream
+            // reduce, then finish the remaining dimensions with a scalar tail
+            long res = acc.reduceLanes(ADD);
+            for (; i < dimensions; i++) {
+                res += in.readByte() * q[i];
+            }
+            scores[iter] = res;
+        }
+    }
+
+    @Override
+    public void scoreBulk(
+        byte[] q,
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        int7DotProductBulk(q, BULK_SIZE, scores); // leaves the raw dot products in scores and the input at the corrections
+        applyCorrectionsBulk(
+            queryLowerInterval,
+            queryUpperInterval,
+            queryComponentSum,
+            queryAdditionalCorrection,
+            similarityFunction,
+            centroidDp,
+            scores
+        );
+    }
+
+    private void applyCorrectionsBulk(
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
+        int i = 0;
+        long offset = in.getFilePointer();
+        float ay = queryLowerInterval;
+        float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
+        float y1 = queryComponentSum;
+        for (; i < limit; i += FLOAT_SPECIES.length()) {
+            var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+            var lx = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES,
+                memorySegment,
+                offset + 4 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).sub(ax).mul(SEVEN_BIT_SCALE);
+            var targetComponentSums = IntVector.fromMemorySegment(
+                INT_SPECIES,
+                memorySegment,
+                offset + 8 * BULK_SIZE + i * Integer.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).convert(VectorOperators.I2F, 0);
+            var additionalCorrections = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES,
+                memorySegment,
+                offset + 12 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            );
+            var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
+            // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+            // qcDist;
+            var res1 = ax.mul(ay).mul(dimensions);
+            var res2 = lx.mul(ay).mul(targetComponentSums);
+            var res3 = ax.mul(ly).mul(y1);
+            var res4 = lx.mul(ly).mul(qcDist);
+            var res = res1.add(res2).add(res3).add(res4);
+            // For euclidean, we need to invert the score and apply the additional correction, which is
+            // assumed to be the squared l2norm of the centroid centered vectors.
+            if (similarityFunction == EUCLIDEAN) {
+                res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
+                res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
+                res.intoArray(scores, i);
+            } else {
+                // For cosine and max inner product, we need to apply the additional correction, which is
+                // assumed to be the non-centered dot-product between the vector and the centroid
+                res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
+                if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                    res.intoArray(scores, i);
+                    // not sure how to do it better
+                    for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
+                        scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+                    }
+                } else {
+                    res = res.add(1f).mul(0.5f).max(0);
+                    res.intoArray(scores, i);
+                }
+            }
+        }
+        in.seek(offset + 16L * BULK_SIZE); // skip the four correction arrays: 4 * BULK_SIZE values of 4 bytes each
+    }
+}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java 
b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 4708a052b05db..856a0cf94410f 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -15,6 +15,7 @@ import org.elasticsearch.logging.Logger; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import java.io.IOException; import java.util.Locale; @@ -42,6 +43,9 @@ public static ESVectorizationProvider getInstance() { /** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */ public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException; + /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. 
*/ + public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException; + // visible for tests static ESVectorizationProvider lookup(boolean testMode) { final int runtimeVersion = Runtime.version().feature(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index abb75352da2f7..9b798870a4284 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -13,6 +13,8 @@ import org.apache.lucene.store.MemorySegmentAccessInput; import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; +import org.elasticsearch.simdvec.internal.MemorySegmentES92Int7VectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; @@ -51,4 +53,16 @@ public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dime } return new ES91Int4VectorsScorer(input, dimension); } + + @Override + public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException { + if (input instanceof MemorySegmentAccessInput msai) { + MemorySegment ms = msai.segmentSliceOrNull(0, input.length()); + if (ms != null) { + return new MemorySegmentES92Int7VectorsScorer(input, dimension, ms); + } + } + return new ES92Int7VectorsScorer(input, dimension); + + } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java new file mode 100644 index 
0000000000000..1b60471b33b59
--- /dev/null
+++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
@@ -0,0 +1,176 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal;
+
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorShape;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Native / panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
+public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
+
+    private static final VectorSpecies<Float> FLOAT_SPECIES;
+    private static final VectorSpecies<Integer> INT_SPECIES;
+
+    static {
+        // default to platform supported bitsize
+        final int vectorBitSize = VectorShape.preferredShape().vectorBitSize();
+        FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(vectorBitSize));
+        INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(vectorBitSize));
+    }
+
+    private final MemorySegment memorySegment;
+
+    /**
+     * Creates a scorer that loads data from {@code memorySegment} while tracking the read position with {@code in}.
+     *
+     * @param in            input whose file pointer is used as the offset into {@code memorySegment}
+     * @param dimensions    number of dimensions (one byte each) of every quantized vector
+     * @param memorySegment segment the vectors and their corrections are loaded from
+     */
+    public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+        super(in, dimensions);
+        this.memorySegment = memorySegment;
+    }
+
+    /** Computes the dot product between {@code q} and the vector at the current position, then skips past that vector. */
+    @Override
+    public long int7DotProduct(byte[] q) throws IOException {
+        final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions);
+        final MemorySegment querySegment = MemorySegment.ofArray(q);
+        final long res = Similarities.dotProduct7u(segment, querySegment, dimensions);
+        in.skipBytes(dimensions);
+        return res;
+    }
+
+    /** Computes {@code count} consecutive dot products against {@code q} and stores them in the first {@code count} scores. */
+    @Override
+    public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
+        // TODO: can we speed up bulks in native code?
+        for (int i = 0; i < count; i++) {
+            scores[i] = int7DotProduct(q);
+        }
+    }
+
+    /**
+     * Scores {@code BULK_SIZE} consecutive vectors against {@code q}: the raw dot products are computed first, then the
+     * corrections that follow the vectors are applied in place on {@code scores}.
+     */
+    @Override
+    public void scoreBulk(
+        byte[] q,
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        int7DotProductBulk(q, BULK_SIZE, scores);
+        applyCorrectionsBulk(
+            queryLowerInterval,
+            queryUpperInterval,
+            queryComponentSum,
+            queryAdditionalCorrection,
+            similarityFunction,
+            centroidDp,
+            scores
+        );
+    }
+
+    /**
+     * Turns the raw dot products in {@code scores} into similarity scores using the corrections found at the current
+     * position, read as four consecutive runs of {@code BULK_SIZE} 4-byte little-endian values: lower intervals, upper
+     * intervals, component sums and additional corrections. The input is then moved past those corrections.
+     */
+    private void applyCorrectionsBulk(
+        float queryLowerInterval,
+        float queryUpperInterval,
+        int queryComponentSum,
+        float queryAdditionalCorrection,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
+        // there is no scalar tail: scores past limit would stay raw dot products while their corrections get skipped
+        assert limit == BULK_SIZE : "BULK_SIZE must be a multiple of the vector length " + FLOAT_SPECIES.length();
+        int i = 0;
+        long offset = in.getFilePointer();
+        float ay = queryLowerInterval;
+        float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
+        float y1 = queryComponentSum;
+        for (; i < limit; i += FLOAT_SPECIES.length()) {
+            var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+            var lx = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES,
+                memorySegment,
+                offset + 4 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).sub(ax).mul(SEVEN_BIT_SCALE);
+            var targetComponentSums = IntVector.fromMemorySegment(
+                INT_SPECIES,
+                memorySegment,
+                offset + 8 * BULK_SIZE + i * Integer.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).convert(VectorOperators.I2F, 0);
+            var additionalCorrections = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES,
+                memorySegment,
+                offset + 12 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            );
+            var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
+            // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+            // qcDist;
+            var res1 = ax.mul(ay).mul(dimensions);
+            var res2 = lx.mul(ay).mul(targetComponentSums);
+            var res3 = ax.mul(ly).mul(y1);
+            var res4 = lx.mul(ly).mul(qcDist);
+            var res = res1.add(res2).add(res3).add(res4);
+            // For euclidean, we need to invert the score and apply the additional correction, which is
+            // assumed to be the squared l2norm of the centroid centered vectors.
+            if (similarityFunction == EUCLIDEAN) {
+                res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
+                res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
+                res.intoArray(scores, i);
+            } else {
+                // For cosine and max inner product, we need to apply the additional correction, which is
+                // assumed to be the non-centered dot-product between the vector and the centroid
+                res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
+                if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                    res.intoArray(scores, i);
+                    // not sure how to do it better
+                    for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
+                        scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+                    }
+                } else {
+                    res = res.add(1f).mul(0.5f).max(0);
+                    res.intoArray(scores, i);
+                }
+            }
+        }
+        in.seek(offset + 16L * BULK_SIZE);
+    }
+}
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java
new file mode 100644
index 0000000000000..31ef6092539e7
--- /dev/null
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java
@@ -0,0 +1,264 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements.
Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.greaterThan;
+
+/** Checks that the default and the (maybe) panamized {@link ES92Int7VectorsScorer} produce the same results. */
+public class ES92Int7VectorScorerTests extends BaseVectorizationTests {
+
+    public void testInt7DotProduct() throws Exception {
+        // only even dimensions are supported
+        final int dimensions = random().nextInt(1, 1000) * 2;
+        final int numVectors = random().nextInt(1, 100);
+        final byte[] vector = new byte[dimensions];
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+                for (int i = 0; i < numVectors; i++) {
+                    for (int j = 0; j < dimensions; j++) {
+                        vector[j] = (byte) random().nextInt(128); // 7-bit quantization
+                    }
+                    out.writeBytes(vector, 0, dimensions);
+                }
+            }
+            final byte[] query = new byte[dimensions];
+            for (int j = 0; j < dimensions; j++) {
+                query[j] = (byte) random().nextInt(128); // 7-bit quantization
+            }
+            try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+                // Work on a slice that has just the right number of bytes to make the test fail with an
+                // index-out-of-bounds in case the implementation reads more than the allowed number of
+                // padding bytes.
+                final IndexInput slice = in.slice("test", 0, (long) dimensions * numVectors);
+                final IndexInput slice2 = in.slice("test2", 0, (long) dimensions * numVectors);
+                final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(slice, dimensions);
+                final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice2, dimensions);
+                for (int i = 0; i < numVectors; i++) {
+                    in.readBytes(vector, 0, dimensions);
+                    long val = VectorUtil.dotProduct(vector, query);
+                    assertEquals(val, defaultScorer.int7DotProduct(query));
+                    assertEquals(val, panamaScorer.int7DotProduct(query));
+                    assertEquals(in.getFilePointer(), slice.getFilePointer());
+                    assertEquals(in.getFilePointer(), slice2.getFilePointer());
+                }
+                assertEquals((long) dimensions * numVectors, in.getFilePointer());
+            }
+        }
+    }
+
+    public void testInt7Score() throws Exception {
+        // only even dimensions are supported
+        final int dimensions = random().nextInt(1, 1000) * 2;
+        final int numVectors = random().nextInt(1, 100);
+
+        float[][] vectors = new float[numVectors][dimensions];
+        final int[] scratch = new int[dimensions];
+        final byte[] qVector = new byte[dimensions];
+        final float[] centroid = new float[dimensions];
+        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+        randomVector(centroid, similarityFunction);
+        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+                for (float[] vector : vectors) {
+                    randomVector(vector, similarityFunction);
+                    OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
+                        vector.clone(),
+                        scratch,
+                        (byte) 7,
+                        centroid
+                    );
+                    for (int j = 0; j < dimensions; j++) {
+                        qVector[j] = (byte) scratch[j];
+                    }
+                    out.writeBytes(qVector, 0, dimensions);
+                    out.writeInt(Float.floatToIntBits(result.lowerInterval()));
+                    out.writeInt(Float.floatToIntBits(result.upperInterval()));
+                    out.writeInt(Float.floatToIntBits(result.additionalCorrection()));
+                    out.writeInt(result.quantizedComponentSum());
+                }
+            }
+            final float[] query = new float[dimensions];
+            randomVector(query, similarityFunction);
+            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+                query.clone(),
+                scratch,
+                (byte) 7,
+                centroid
+            );
+            byte[] qQuery = new byte[dimensions];
+            for (int i = 0; i < dimensions; i++) {
+                qQuery[i] = (byte) scratch[i];
+            }
+
+            float centroidDp = VectorUtil.dotProduct(centroid, centroid);
+
+            try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+                // Work on a slice that has just the right number of bytes to make the test fail with an
+                // index-out-of-bounds in case the implementation reads more than the allowed number of
+                // padding bytes.
+                final IndexInput slice = in.slice("test", 0, (long) (dimensions + 16) * numVectors);
+                final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(in, dimensions);
+                final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice, dimensions);
+                for (int i = 0; i < numVectors; i++) {
+                    float scoreDefault = defaultScorer.score(
+                        qQuery,
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp
+                    );
+                    float scorePanama = panamaScorer.score(
+                        qQuery,
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp
+                    );
+                    assertEquals(scoreDefault, scorePanama, 0.001f);
+                    float realSimilarity = similarityFunction.compare(vectors[i], query);
+                    float accuracy = realSimilarity > scoreDefault ? scoreDefault / realSimilarity : realSimilarity / scoreDefault;
+                    assertThat(accuracy, greaterThan(0.98f));
+                    assertEquals(in.getFilePointer(), slice.getFilePointer());
+                }
+                assertEquals((long) (dimensions + 16) * numVectors, in.getFilePointer());
+            }
+        }
+    }
+
+    public void testInt7ScoreBulk() throws Exception {
+        // only even dimensions are supported
+        final int dimensions = random().nextInt(1, 1000) * 2;
+        final int numVectors = random().nextInt(1, 10) * ES92Int7VectorsScorer.BULK_SIZE;
+        final float[][] vectors = new float[numVectors][dimensions];
+        final int[] quantizedScratch = new int[dimensions];
+        final byte[] quantizeVector = new byte[dimensions];
+        final float[] centroid = new float[dimensions];
+        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+        randomVector(centroid, similarityFunction);
+
+        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+                OptimizedScalarQuantizer.QuantizationResult[] results =
+                    new OptimizedScalarQuantizer.QuantizationResult[ES92Int7VectorsScorer.BULK_SIZE];
+                for (int i = 0; i < numVectors; i += ES92Int7VectorsScorer.BULK_SIZE) {
+                    for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
+                        randomVector(vectors[i + j], similarityFunction);
+                        results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 7, centroid);
+                        for (int k = 0; k < dimensions; k++) {
+                            quantizeVector[k] = (byte) quantizedScratch[k];
+                        }
+                        out.writeBytes(quantizeVector, 0, dimensions);
+                    }
+                    writeCorrections(results, out);
+                }
+            }
+            final float[] query = new float[dimensions];
+            final byte[] quantizeQuery = new byte[dimensions];
+            randomVector(query, similarityFunction);
+            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+                query.clone(),
+                quantizedScratch,
+                (byte) 7,
+                centroid
+            );
+            for (int j = 0; j < dimensions; j++) {
+                quantizeQuery[j] = (byte) quantizedScratch[j];
+            }
+            float centroidDp = VectorUtil.dotProduct(centroid, centroid);
+
+            try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+                // Work on a slice that has just the right number of bytes to make the test fail with an
+                // index-out-of-bounds in case the implementation reads more than the allowed number of
+                // padding bytes.
+                final IndexInput slice = in.slice("test", 0, (long) (dimensions + 16) * numVectors);
+                final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(in, dimensions);
+                final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice, dimensions);
+                float[] scoresDefault = new float[ES92Int7VectorsScorer.BULK_SIZE];
+                float[] scoresPanama = new float[ES92Int7VectorsScorer.BULK_SIZE];
+                for (int i = 0; i < numVectors; i += ES92Int7VectorsScorer.BULK_SIZE) {
+                    defaultScorer.scoreBulk(
+                        quantizeQuery,
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp,
+                        scoresDefault
+                    );
+                    panamaScorer.scoreBulk(
+                        quantizeQuery,
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp,
+                        scoresPanama
+                    );
+                    for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
+                        assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
+                        float realSimilarity = similarityFunction.compare(vectors[i + j], query);
+                        float accuracy = realSimilarity > scoresDefault[j]
+                            ? scoresDefault[j] / realSimilarity
+                            : realSimilarity / scoresDefault[j];
+                        assertThat(accuracy, greaterThan(0.98f));
+                    }
+                    assertEquals(in.getFilePointer(), slice.getFilePointer());
+                }
+                assertEquals((long) (dimensions + 16) * numVectors, in.getFilePointer());
+            }
+        }
+    }
+
+    /** Writes the corrections column-wise: lower intervals, upper intervals, component sums, then additional corrections. */
+    private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            int targetComponentSum = correction.quantizedComponentSum();
+            out.writeInt(targetComponentSum);
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        }
+    }
+
+    private void randomVector(float[] vector, VectorSimilarityFunction vectorSimilarityFunction) {
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = random().nextFloat();
+        }
+        if (vectorSimilarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+            VectorUtil.l2normalize(vector);
+        }
+    }
+}