diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7ScorerBenchmark.java
new file mode 100644
index 0000000000000..6e97893354349
--- /dev/null
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7ScorerBenchmark.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+// first iteration is complete garbage, so make sure we really warmup
+@Warmup(iterations = 4, time = 1)
+// real iterations. not useful to spend tons of time here, better to fork more
+@Measurement(iterations = 5, time = 1)
+// engage some noise reduction
+@Fork(value = 1)
+public class Int7ScorerBenchmark {
+
+ static {
+ LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+ }
+
+ @Param({ "384", "782", "1024" })
+ int dims;
+
+ int numVectors = 20 * ES92Int7VectorsScorer.BULK_SIZE;
+ int numQueries = 5;
+
+ byte[] scratch;
+ byte[][] binaryVectors;
+ byte[][] binaryQueries;
+ float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
+
+ ES92Int7VectorsScorer scorer;
+ Directory dir;
+ IndexInput in;
+
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections;
+ float centroidDp;
+
+ @Setup
+ public void setup() throws IOException {
+ binaryVectors = new byte[numVectors][dims];
+ dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
+ try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
+ for (byte[] binaryVector : binaryVectors) {
+ for (int i = 0; i < dims; i++) {
+ // 4-bit quantization
+ binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(128);
+ }
+ out.writeBytes(binaryVector, 0, binaryVector.length);
+ ThreadLocalRandom.current().nextBytes(binaryVector);
+ out.writeBytes(binaryVector, 0, 16); // corrections
+ }
+ }
+
+ queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
+ ThreadLocalRandom.current().nextFloat(),
+ ThreadLocalRandom.current().nextFloat(),
+ ThreadLocalRandom.current().nextFloat(),
+ Short.toUnsignedInt((short) ThreadLocalRandom.current().nextInt())
+ );
+ centroidDp = ThreadLocalRandom.current().nextFloat();
+
+ in = dir.openInput("vectors", IOContext.DEFAULT);
+ binaryQueries = new byte[numVectors][dims];
+ for (byte[] binaryVector : binaryVectors) {
+ for (int i = 0; i < dims; i++) {
+ // 7-bit quantization
+ binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(128);
+ }
+ }
+
+ scratch = new byte[dims];
+ scorer = ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(in, dims);
+ }
+
+ @TearDown
+ public void teardown() throws IOException {
+ IOUtils.close(dir, in);
+ }
+
+ @Benchmark
+ @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+ public void scoreFromMemorySegment(Blackhole bh) throws IOException {
+ for (int j = 0; j < numQueries; j++) {
+ in.seek(0);
+ for (int i = 0; i < numVectors; i++) {
+ bh.consume(
+ scorer.score(
+ binaryQueries[j],
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
+ VectorSimilarityFunction.EUCLIDEAN,
+ centroidDp
+ )
+ );
+ }
+ }
+ }
+
+ @Benchmark
+ @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+ public void scoreFromMemorySegmentBulk(Blackhole bh) throws IOException {
+ for (int j = 0; j < numQueries; j++) {
+ in.seek(0);
+ for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
+ scorer.scoreBulk(
+ binaryQueries[j],
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
+ VectorSimilarityFunction.EUCLIDEAN,
+ centroidDp,
+ scores
+ );
+ for (float score : scores) {
+ bh.consume(score);
+ }
+ }
+ }
+ }
+}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java
new file mode 100644
index 0000000000000..c405e0ad33677
--- /dev/null
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/**
+ * Scorer for 7 bit quantized vectors stored in a {@link IndexInput}.
+ * Queries are expected to be quantized using 7 bits as well.
+ * */
+public class ES92Int7VectorsScorer {
+
+ public static final int BULK_SIZE = 16;
+ protected static final float SEVEN_BIT_SCALE = 1f / ((1 << 7) - 1);
+
+ /** The wrapper {@link IndexInput}. */
+ protected final IndexInput in;
+ protected final int dimensions;
+
+ private final float[] lowerIntervals = new float[BULK_SIZE];
+ private final float[] upperIntervals = new float[BULK_SIZE];
+ private final int[] targetComponentSums = new int[BULK_SIZE];
+ private final float[] additionalCorrections = new float[BULK_SIZE];
+
+ /** Sole constructor, called by sub-classes. */
+ public ES92Int7VectorsScorer(IndexInput in, int dimensions) {
+ this.in = in;
+ this.dimensions = dimensions;
+ }
+
+ /**
+ * compute the quantize distance between the provided quantized query and the quantized vector
+ * that is read from the wrapped {@link IndexInput}.
+ */
+ public long int7DotProduct(byte[] b) throws IOException {
+ int total = 0;
+ for (int i = 0; i < dimensions; i++) {
+ total += in.readByte() * b[i];
+ }
+ return total;
+ }
+
+ /**
+ * compute the quantize distance between the provided quantized query and the quantized vectors
+ * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
+ * determined by {code count} and the results are stored in the provided {@code scores} array.
+ */
+ public void int7DotProductBulk(byte[] b, int count, float[] scores) throws IOException {
+ for (int i = 0; i < count; i++) {
+ scores[i] = int7DotProduct(b);
+ }
+ }
+
+ /**
+ * Computes the score by applying the necessary corrections to the provided quantized distance.
+ */
+ public float score(
+ byte[] q,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp
+ ) throws IOException {
+ float score = int7DotProduct(q);
+ in.readFloats(lowerIntervals, 0, 3);
+ int addition = in.readInt();
+ return applyCorrections(
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ lowerIntervals[0],
+ lowerIntervals[1],
+ addition,
+ lowerIntervals[2],
+ score
+ );
+ }
+
+ /**
+ * compute the distance between the provided quantized query and the quantized vectors that are
+ * read from the wrapped {@link IndexInput}.
+ *
+ *
The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
+ * input is as follows: First the quantized vectors are read from the input,then all the lower
+ * intervals as floats, then all the upper intervals as floats, then all the target component sums
+ * as shorts, and finally all the additional corrections as floats.
+ *
+ *
The results are stored in the provided scores array.
+ */
+ public void scoreBulk(
+ byte[] q,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ int7DotProductBulk(q, BULK_SIZE, scores);
+ in.readFloats(lowerIntervals, 0, BULK_SIZE);
+ in.readFloats(upperIntervals, 0, BULK_SIZE);
+ in.readInts(targetComponentSums, 0, BULK_SIZE);
+ in.readFloats(additionalCorrections, 0, BULK_SIZE);
+ for (int i = 0; i < BULK_SIZE; i++) {
+ scores[i] = applyCorrections(
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ lowerIntervals[i],
+ upperIntervals[i],
+ targetComponentSums[i],
+ additionalCorrections[i],
+ scores[i]
+ );
+ }
+ }
+
+ /**
+ * Computes the score by applying the necessary corrections to the provided quantized distance.
+ */
+ public float applyCorrections(
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float lowerInterval,
+ float upperInterval,
+ int targetComponentSum,
+ float additionalCorrection,
+ float qcDist
+ ) {
+ float ax = lowerInterval;
+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+ float lx = (upperInterval - ax) * SEVEN_BIT_SCALE;
+ float ay = queryLowerInterval;
+ float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
+ float y1 = queryComponentSum;
+ float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ score = queryAdditionalCorrection + additionalCorrection - 2 * score;
+ return Math.max(1 / (1f + score), 0);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ score += queryAdditionalCorrection + additionalCorrection - centroidDp;
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
+ }
+ return Math.max((1f + score) / 2f, 0);
+ }
+ }
+}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
index c3091dcb96882..5b14b39d37fb0 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
@@ -51,6 +51,10 @@ public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, i
return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
}
+ public static ES92Int7VectorsScorer getES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException {
+ return ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(input, dimension);
+ }
+
public static long ipByteBinByte(byte[] q, byte[] d) {
if (q.length != d.length * B_QUERY) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
index 5bdd7a724ceda..4c4cd98bdd781 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
@@ -12,8 +12,7 @@
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
-
-import java.io.IOException;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
private final ESVectorUtilSupport vectorUtilSupport;
@@ -28,12 +27,17 @@ public ESVectorUtilSupport getVectorUtilSupport() {
}
@Override
- public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
+ public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) {
return new ES91OSQVectorsScorer(input, dimension);
}
@Override
- public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
+ public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) {
return new ES91Int4VectorsScorer(input, dimension);
}
+
+ @Override
+ public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) {
+ return new ES92Int7VectorsScorer(input, dimension);
+ }
}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
index 719284f48471c..d174c31401f02 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
@@ -12,6 +12,7 @@
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import java.io.IOException;
import java.util.Objects;
@@ -35,6 +36,9 @@ public static ESVectorizationProvider getInstance() {
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
+ /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */
+ public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException;
+
// visible for tests
static ESVectorizationProvider lookup(boolean testMode) {
return new DefaultESVectorizationProvider();
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
new file mode 100644
index 0000000000000..6edf60fff1c83
--- /dev/null
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
@@ -0,0 +1,352 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal;
+
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.Vector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorShape;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
+
+import static java.nio.ByteOrder.LITTLE_ENDIAN;
+import static jdk.incubator.vector.VectorOperators.ADD;
+import static jdk.incubator.vector.VectorOperators.B2I;
+import static jdk.incubator.vector.VectorOperators.B2S;
+import static jdk.incubator.vector.VectorOperators.S2I;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
+public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
+
+ private static final VectorSpecies BYTE_SPECIES_64 = ByteVector.SPECIES_64;
+ private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+
+ private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128;
+ private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256;
+
+ private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128;
+ private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256;
+ private static final VectorSpecies INT_SPECIES_512 = IntVector.SPECIES_512;
+
+ private static final int VECTOR_BITSIZE;
+ private static final VectorSpecies FLOAT_SPECIES;
+ private static final VectorSpecies INT_SPECIES;
+
+ static {
+ // default to platform supported bitsize
+ VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
+ FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
+ INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
+ }
+
+ private final MemorySegment memorySegment;
+
+ public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+ super(in, dimensions);
+ this.memorySegment = memorySegment;
+ }
+
+ @Override
+ public long int7DotProduct(byte[] q) throws IOException {
+ assert dimensions == q.length;
+ int i = 0;
+ int res = 0;
+ // only vectorize if we'll at least enter the loop a single time
+ if (dimensions >= 16) {
+ // compute vectorized dot product consistent with VPDPBUSD instruction
+ if (VECTOR_BITSIZE >= 512) {
+ i += BYTE_SPECIES_128.loopBound(dimensions);
+ res += dotProductBody512(q, i);
+ } else if (VECTOR_BITSIZE == 256) {
+ i += BYTE_SPECIES_64.loopBound(dimensions);
+ res += dotProductBody256(q, i);
+ } else {
+ // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
+ i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
+ res += dotProductBody128(q, i);
+ }
+ // scalar tail
+ while (i < dimensions) {
+ res += in.readByte() * q[i++];
+ }
+ return res;
+ } else {
+ return super.int7DotProduct(q);
+ }
+ }
+
+ private int dotProductBody512(byte[] q, int limit) throws IOException {
+ IntVector acc = IntVector.zero(INT_SPECIES_512);
+ long offset = in.getFilePointer();
+ for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
+ ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
+ ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
+
+ // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
+ Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
+ Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
+ Vector prod16 = va16.mul(vb16);
+
+ // 32-bit add
+ Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
+ acc = acc.add(prod32);
+ }
+
+ in.seek(offset + limit); // advance the input stream
+ // reduce
+ return acc.reduceLanes(ADD);
+ }
+
+ private int dotProductBody256(byte[] q, int limit) throws IOException {
+ IntVector acc = IntVector.zero(INT_SPECIES_256);
+ long offset = in.getFilePointer();
+ for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
+ ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+ ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+ // 32-bit multiply and add into accumulator
+ Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
+ Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
+ acc = acc.add(va32.mul(vb32));
+ }
+ in.seek(offset + limit);
+ // reduce
+ return acc.reduceLanes(ADD);
+ }
+
+ private int dotProductBody128(byte[] q, int limit) throws IOException {
+ IntVector acc = IntVector.zero(IntVector.SPECIES_128);
+ long offset = in.getFilePointer();
+ // 4 bytes at a time (re-loading half the vector each time!)
+ for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
+ // load 8 bytes
+ ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+ ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+ // process first "half" only: 16-bit multiply
+ Vector va16 = va8.convert(B2S, 0);
+ Vector vb16 = vb8.convert(B2S, 0);
+ Vector prod16 = va16.mul(vb16);
+
+ // 32-bit add
+ acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
+ }
+ in.seek(offset + limit);
+ // reduce
+ return acc.reduceLanes(ADD);
+ }
+
+ @Override
+ public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
+ assert dimensions == q.length;
+ // only vectorize if we'll at least enter the loop a single time
+ if (dimensions >= 16) {
+ // compute vectorized dot product consistent with VPDPBUSD instruction
+ if (VECTOR_BITSIZE >= 512) {
+ dotProductBody512Bulk(q, count, scores);
+ } else if (VECTOR_BITSIZE == 256) {
+ dotProductBody256Bulk(q, count, scores);
+ } else {
+ // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
+ dotProductBody128Bulk(q, count, scores);
+ }
+ } else {
+ int7DotProductBulk(q, count, scores);
+ }
+ }
+
+ private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException {
+ int limit = BYTE_SPECIES_128.loopBound(dimensions);
+ for (int iter = 0; iter < count; iter++) {
+ IntVector acc = IntVector.zero(INT_SPECIES_512);
+ long offset = in.getFilePointer();
+ int i = 0;
+ for (; i < limit; i += BYTE_SPECIES_128.length()) {
+ ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
+ ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
+
+ // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
+ Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
+ Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
+ Vector prod16 = va16.mul(vb16);
+
+ // 32-bit add
+ Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
+ acc = acc.add(prod32);
+ }
+
+ in.seek(offset + limit); // advance the input stream
+ // reduce
+ long res = acc.reduceLanes(ADD);
+ for (; i < dimensions; i++) {
+ res += in.readByte() * q[i];
+ }
+ scores[iter] = res;
+ }
+ }
+
+ private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException {
+ int limit = BYTE_SPECIES_128.loopBound(dimensions);
+ for (int iter = 0; iter < count; iter++) {
+ IntVector acc = IntVector.zero(INT_SPECIES_256);
+ long offset = in.getFilePointer();
+ int i = 0;
+ for (; i < limit; i += BYTE_SPECIES_64.length()) {
+ ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+ ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+ // 32-bit multiply and add into accumulator
+ Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
+ Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
+ acc = acc.add(va32.mul(vb32));
+ }
+ in.seek(offset + limit);
+ // reduce
+ long res = acc.reduceLanes(ADD);
+ for (; i < dimensions; i++) {
+ res += in.readByte() * q[i];
+ }
+ scores[iter] = res;
+ }
+ }
+
+ private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException {
+ int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
+ for (int iter = 0; iter < count; iter++) {
+ IntVector acc = IntVector.zero(IntVector.SPECIES_128);
+ long offset = in.getFilePointer();
+ // 4 bytes at a time (re-loading half the vector each time!)
+ int i = 0;
+ for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
+ // load 8 bytes
+ ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+ ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+ // process first "half" only: 16-bit multiply
+ Vector va16 = va8.convert(B2S, 0);
+ Vector vb16 = vb8.convert(B2S, 0);
+ Vector prod16 = va16.mul(vb16);
+
+ // 32-bit add
+ acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
+ }
+ in.seek(offset + limit);
+ // reduce
+ long res = acc.reduceLanes(ADD);
+ for (; i < dimensions; i++) {
+ res += in.readByte() * q[i];
+ }
+ scores[iter] = res;
+ }
+ }
+
+ @Override
+ public void scoreBulk(
+ byte[] q,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ int7DotProductBulk(q, BULK_SIZE, scores);
+ applyCorrectionsBulk(
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ );
+ }
+
+ private void applyCorrectionsBulk(
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
+ int i = 0;
+ long offset = in.getFilePointer();
+ float ay = queryLowerInterval;
+ float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
+ float y1 = queryComponentSum;
+ for (; i < limit; i += FLOAT_SPECIES.length()) {
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var lx = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES,
+ memorySegment,
+ offset + 4 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).sub(ax).mul(SEVEN_BIT_SCALE);
+ var targetComponentSums = IntVector.fromMemorySegment(
+ INT_SPECIES,
+ memorySegment,
+ offset + 8 * BULK_SIZE + i * Integer.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).convert(VectorOperators.I2F, 0);
+ var additionalCorrections = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES,
+ memorySegment,
+ offset + 12 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ );
+ var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
+ // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+ // qcDist;
+ var res1 = ax.mul(ay).mul(dimensions);
+ var res2 = lx.mul(ay).mul(targetComponentSums);
+ var res3 = ax.mul(ly).mul(y1);
+ var res4 = lx.mul(ly).mul(qcDist);
+ var res = res1.add(res2).add(res3).add(res4);
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
+ res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
+ res.intoArray(scores, i);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ res.intoArray(scores, i);
+ // not sure how to do it better
+ for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
+ scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+ }
+ } else {
+ res = res.add(1f).mul(0.5f).max(0);
+ res.intoArray(scores, i);
+ }
+ }
+ }
+ in.seek(offset + 16L * BULK_SIZE);
+ }
+}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
index 4708a052b05db..856a0cf94410f 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
@@ -15,6 +15,7 @@
import org.elasticsearch.logging.Logger;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import java.io.IOException;
import java.util.Locale;
@@ -42,6 +43,9 @@ public static ESVectorizationProvider getInstance() {
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
+ /** Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}. */
+ public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException;
+
// visible for tests
static ESVectorizationProvider lookup(boolean testMode) {
final int runtimeVersion = Runtime.version().feature();
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
index abb75352da2f7..9b798870a4284 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
@@ -13,6 +13,8 @@
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+import org.elasticsearch.simdvec.internal.MemorySegmentES92Int7VectorsScorer;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
@@ -51,4 +53,16 @@ public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dime
}
return new ES91Int4VectorsScorer(input, dimension);
}
+
+ @Override
+ public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException {
+ if (input instanceof MemorySegmentAccessInput msai) {
+ MemorySegment ms = msai.segmentSliceOrNull(0, input.length());
+ if (ms != null) {
+ return new MemorySegmentES92Int7VectorsScorer(input, dimension, ms);
+ }
+ }
+ return new ES92Int7VectorsScorer(input, dimension);
+
+ }
}
diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
new file mode 100644
index 0000000000000..1b60471b33b59
--- /dev/null
+++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal;
+
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorShape;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Native / panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
+public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
+
+ private static final VectorSpecies FLOAT_SPECIES;
+ private static final VectorSpecies INT_SPECIES;
+
+ static {
+ // default to platform supported bitsize
+ final int vectorBitSize = VectorShape.preferredShape().vectorBitSize();
+ FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(vectorBitSize));
+ INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(vectorBitSize));
+ }
+
+ private final MemorySegment memorySegment;
+
+ public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+ super(in, dimensions);
+ this.memorySegment = memorySegment;
+ }
+
+ @Override
+ public long int7DotProduct(byte[] q) throws IOException {
+ final MemorySegment segment = memorySegment.asSlice(in.getFilePointer(), dimensions);
+ final MemorySegment querySegment = MemorySegment.ofArray(q);
+ final long res = Similarities.dotProduct7u(segment, querySegment, dimensions);
+ in.skipBytes(dimensions);
+ return res;
+ }
+
+ @Override
+ public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
+ // TODO: can we speed up bulks in native code?
+ for (int i = 0; i < count; i++) {
+ scores[i] = int7DotProduct(q);
+ }
+ }
+
+ @Override
+ public void scoreBulk(
+ byte[] q,
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ int7DotProductBulk(q, BULK_SIZE, scores);
+ applyCorrectionsBulk(
+ queryLowerInterval,
+ queryUpperInterval,
+ queryComponentSum,
+ queryAdditionalCorrection,
+ similarityFunction,
+ centroidDp,
+ scores
+ );
+ }
+
+ private void applyCorrectionsBulk(
+ float queryLowerInterval,
+ float queryUpperInterval,
+ int queryComponentSum,
+ float queryAdditionalCorrection,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
+ int i = 0;
+ long offset = in.getFilePointer();
+ float ay = queryLowerInterval;
+ float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
+ float y1 = queryComponentSum;
+ for (; i < limit; i += FLOAT_SPECIES.length()) {
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var lx = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES,
+ memorySegment,
+ offset + 4 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).sub(ax).mul(SEVEN_BIT_SCALE);
+ var targetComponentSums = IntVector.fromMemorySegment(
+ INT_SPECIES,
+ memorySegment,
+ offset + 8 * BULK_SIZE + i * Integer.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).convert(VectorOperators.I2F, 0);
+ var additionalCorrections = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES,
+ memorySegment,
+ offset + 12 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ );
+ var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
+ // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+ // qcDist;
+ var res1 = ax.mul(ay).mul(dimensions);
+ var res2 = lx.mul(ay).mul(targetComponentSums);
+ var res3 = ax.mul(ly).mul(y1);
+ var res4 = lx.mul(ly).mul(qcDist);
+ var res = res1.add(res2).add(res3).add(res4);
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
+ res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
+ res.intoArray(scores, i);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ res.intoArray(scores, i);
+ // not sure how to do it better
+ for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
+ scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+ }
+ } else {
+ res = res.add(1f).mul(0.5f).max(0);
+ res.intoArray(scores, i);
+ }
+ }
+ }
+ in.seek(offset + 16L * BULK_SIZE);
+ }
+}
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java
new file mode 100644
index 0000000000000..31ef6092539e7
--- /dev/null
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java
@@ -0,0 +1,264 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
+import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.greaterThan;
+
+public class ES92Int7VectorScorerTests extends BaseVectorizationTests {
+
+ public void testInt7DotProduct() throws Exception {
+ // only even dimensions are supported
+ final int dimensions = random().nextInt(1, 1000) * 2;
+ final int numVectors = random().nextInt(1, 100);
+ final byte[] vector = new byte[dimensions];
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+ for (int i = 0; i < numVectors; i++) {
+ for (int j = 0; j < dimensions; j++) {
+ vector[j] = (byte) random().nextInt(128); // 7-bit quantization
+ }
+ out.writeBytes(vector, 0, dimensions);
+ }
+ }
+ final byte[] query = new byte[dimensions];
+ for (int j = 0; j < dimensions; j++) {
+ query[j] = (byte) random().nextInt(128); // 7-bit quantization
+ }
+ try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+ // Work on a slice that has just the right number of bytes to make the test fail with an
+ // index-out-of-bounds in case the implementation reads more than the allowed number of
+ // padding bytes.
+ final IndexInput slice = in.slice("test", 0, (long) dimensions * numVectors);
+ final IndexInput slice2 = in.slice("test2", 0, (long) dimensions * numVectors);
+ final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(slice, dimensions);
+ final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice2, dimensions);
+ for (int i = 0; i < numVectors; i++) {
+ in.readBytes(vector, 0, dimensions);
+ long val = VectorUtil.dotProduct(vector, query);
+ assertEquals(val, defaultScorer.int7DotProduct(query));
+ assertEquals(val, panamaScorer.int7DotProduct(query));
+ assertEquals(in.getFilePointer(), slice.getFilePointer());
+ assertEquals(in.getFilePointer(), slice2.getFilePointer());
+ }
+ assertEquals((long) dimensions * numVectors, in.getFilePointer());
+ }
+ }
+ }
+
+ public void testInt7Score() throws Exception {
+ // only even dimensions are supported
+ final int dimensions = random().nextInt(1, 1000) * 2;
+ final int numVectors = random().nextInt(1, 100);
+
+ float[][] vectors = new float[numVectors][dimensions];
+ final int[] scratch = new int[dimensions];
+ final byte[] qVector = new byte[dimensions];
+ final float[] centroid = new float[dimensions];
+ VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+ randomVector(centroid, similarityFunction);
+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+ for (float[] vector : vectors) {
+ randomVector(vector, similarityFunction);
+ OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
+ vector.clone(),
+ scratch,
+ (byte) 7,
+ centroid
+ );
+ for (int j = 0; j < dimensions; j++) {
+ qVector[j] = (byte) scratch[j];
+ }
+ out.writeBytes(qVector, 0, dimensions);
+ out.writeInt(Float.floatToIntBits(result.lowerInterval()));
+ out.writeInt(Float.floatToIntBits(result.upperInterval()));
+ out.writeInt(Float.floatToIntBits(result.additionalCorrection()));
+ out.writeInt(result.quantizedComponentSum());
+ }
+ }
+ final float[] query = new float[dimensions];
+ randomVector(query, similarityFunction);
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+ query.clone(),
+ scratch,
+ (byte) 7,
+ centroid
+ );
+ byte[] qQuery = new byte[dimensions];
+ for (int i = 0; i < dimensions; i++) {
+ qQuery[i] = (byte) scratch[i];
+ }
+
+ float centroidDp = VectorUtil.dotProduct(centroid, centroid);
+
+ try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+ // Work on a slice that has just the right number of bytes to make the test fail with an
+ // index-out-of-bounds in case the implementation reads more than the allowed number of
+ // padding bytes.
+ final IndexInput slice = in.slice("test", 0, (long) (dimensions + 16) * numVectors);
+ final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(in, dimensions);
+ final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice, dimensions);
+ for (int i = 0; i < numVectors; i++) {
+ float scoreDefault = defaultScorer.score(
+ qQuery,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
+ similarityFunction,
+ centroidDp
+ );
+ float scorePanama = panamaScorer.score(
+ qQuery,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
+ similarityFunction,
+ centroidDp
+ );
+ assertEquals(scoreDefault, scorePanama, 0.001f);
+ float realSimilarity = similarityFunction.compare(vectors[i], query);
+ float accuracy = realSimilarity > scoreDefault ? scoreDefault / realSimilarity : realSimilarity / scoreDefault;
+ assertThat(accuracy, greaterThan(0.98f));
+ assertEquals(in.getFilePointer(), slice.getFilePointer());
+ }
+ assertEquals((long) (dimensions + 16) * numVectors, in.getFilePointer());
+ }
+ }
+ }
+
+ public void testInt7ScoreBulk() throws Exception {
+ // only even dimensions are supported
+ final int dimensions = random().nextInt(1, 1000) * 2;
+ final int numVectors = random().nextInt(1, 10) * ES91Int4VectorsScorer.BULK_SIZE;
+ final float[][] vectors = new float[numVectors][dimensions];
+ final int[] quantizedScratch = new int[dimensions];
+ final byte[] quantizeVector = new byte[dimensions];
+ final float[] centroid = new float[dimensions];
+ VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+ randomVector(centroid, similarityFunction);
+
+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+ OptimizedScalarQuantizer.QuantizationResult[] results =
+ new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE];
+ for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
+ for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
+ randomVector(vectors[i + j], similarityFunction);
+ results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 7, centroid);
+ for (int k = 0; k < dimensions; k++) {
+ quantizeVector[k] = (byte) quantizedScratch[k];
+ }
+ out.writeBytes(quantizeVector, 0, dimensions);
+ }
+ writeCorrections(results, out);
+ }
+ }
+ final float[] query = new float[dimensions];
+ final byte[] quantizeQuery = new byte[dimensions];
+ randomVector(query, similarityFunction);
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+ query.clone(),
+ quantizedScratch,
+ (byte) 7,
+ centroid
+ );
+ for (int j = 0; j < dimensions; j++) {
+ quantizeQuery[j] = (byte) quantizedScratch[j];
+ }
+ float centroidDp = VectorUtil.dotProduct(centroid, centroid);
+
+ try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+ // Work on a slice that has just the right number of bytes to make the test fail with an
+ // index-out-of-bounds in case the implementation reads more than the allowed number of
+ // padding bytes.
+ final IndexInput slice = in.slice("test", 0, (long) (dimensions + 16) * numVectors);
+ final ES92Int7VectorsScorer defaultScorer = defaultProvider().newES92Int7VectorsScorer(in, dimensions);
+ final ES92Int7VectorsScorer panamaScorer = maybePanamaProvider().newES92Int7VectorsScorer(slice, dimensions);
+ float[] scoresDefault = new float[ES91Int4VectorsScorer.BULK_SIZE];
+ float[] scoresPanama = new float[ES91Int4VectorsScorer.BULK_SIZE];
+ for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
+ defaultScorer.scoreBulk(
+ quantizeQuery,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
+ similarityFunction,
+ centroidDp,
+ scoresDefault
+ );
+ panamaScorer.scoreBulk(
+ quantizeQuery,
+ queryCorrections.lowerInterval(),
+ queryCorrections.upperInterval(),
+ queryCorrections.quantizedComponentSum(),
+ queryCorrections.additionalCorrection(),
+ similarityFunction,
+ centroidDp,
+ scoresPanama
+ );
+ for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+ assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
+ float realSimilarity = similarityFunction.compare(vectors[i + j], query);
+ float accuracy = realSimilarity > scoresDefault[j]
+ ? scoresDefault[j] / realSimilarity
+ : realSimilarity / scoresDefault[j];
+ assertThat(accuracy, greaterThan(0.98f));
+ }
+ assertEquals(in.getFilePointer(), slice.getFilePointer());
+ }
+ assertEquals((long) (dimensions + 16) * numVectors, in.getFilePointer());
+ }
+ }
+ }
+
+ private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
+ for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+ out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+ }
+ for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+ out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+ }
+ for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+ int targetComponentSum = correction.quantizedComponentSum();
+ out.writeInt(targetComponentSum);
+ }
+ for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+ out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+ }
+ }
+
+ private void randomVector(float[] vector, VectorSimilarityFunction vectorSimilarityFunction) {
+ for (int i = 0; i < vector.length; i++) {
+ vector[i] = random().nextFloat();
+ }
+ if (vectorSimilarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+ VectorUtil.l2normalize(vector);
+ }
+ }
+}