Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.benchmark.vector;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.io.IOException;
import java.nio.file.Files;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warmup
@Warmup(iterations = 4, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 5, time = 1)
// engage some noise reduction
@Fork(value = 1)
public class Int7ScorerBenchmark {

    static {
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    @Param({ "384", "782", "1024" })
    int dims;

    int numVectors = 20 * ES92Int7VectorsScorer.BULK_SIZE;
    int numQueries = 5;

    byte[] scratch;
    byte[][] binaryVectors;
    byte[][] binaryQueries;
    float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];

    ES92Int7VectorsScorer scorer;
    Directory dir;
    IndexInput in;

    OptimizedScalarQuantizer.QuantizationResult queryCorrections;
    float centroidDp;

    /**
     * Writes {@code numVectors} random 7-bit quantized vectors (each followed by 16 bytes of
     * per-vector corrections) to a temporary directory, then opens an input over them and
     * creates the scorer plus random quantized queries and query-level corrections.
     */
    @Setup
    public void setup() throws IOException {
        binaryVectors = new byte[numVectors][dims];
        dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
        try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
            for (byte[] binaryVector : binaryVectors) {
                for (int i = 0; i < dims; i++) {
                    // 7-bit quantization: component values in [0, 127]
                    binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(128);
                }
                out.writeBytes(binaryVector, 0, binaryVector.length);
                // 16 bytes of per-vector corrections (3 floats + 1 int); random garbage is
                // fine for benchmarking, reuse the vector array as a source of random bytes
                ThreadLocalRandom.current().nextBytes(binaryVector);
                out.writeBytes(binaryVector, 0, 16); // corrections
            }
        }

        queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
            ThreadLocalRandom.current().nextFloat(),
            ThreadLocalRandom.current().nextFloat(),
            ThreadLocalRandom.current().nextFloat(),
            Short.toUnsignedInt((short) ThreadLocalRandom.current().nextInt())
        );
        centroidDp = ThreadLocalRandom.current().nextFloat();

        in = dir.openInput("vectors", IOContext.DEFAULT);
        binaryQueries = new byte[numVectors][dims];
        // fix: fill the query vectors (the original loop re-randomized binaryVectors,
        // leaving every query all-zeros)
        for (byte[] binaryQuery : binaryQueries) {
            for (int i = 0; i < dims; i++) {
                // 7-bit quantization
                binaryQuery[i] = (byte) ThreadLocalRandom.current().nextInt(128);
            }
        }

        scratch = new byte[dims];
        scorer = ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(in, dims);
    }

    @TearDown
    public void teardown() throws IOException {
        // close the input before the directory that owns it
        IOUtils.close(in, dir);
    }

    /** Scores every stored vector against each query, one vector at a time. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void scoreFromMemorySegment(Blackhole bh) throws IOException {
        for (int j = 0; j < numQueries; j++) {
            in.seek(0);
            for (int i = 0; i < numVectors; i++) {
                bh.consume(
                    scorer.score(
                        binaryQueries[j],
                        queryCorrections.lowerInterval(),
                        queryCorrections.upperInterval(),
                        queryCorrections.quantizedComponentSum(),
                        queryCorrections.additionalCorrection(),
                        VectorSimilarityFunction.EUCLIDEAN,
                        centroidDp
                    )
                );
            }
        }
    }

    /** Scores every stored vector against each query in bulks of {@link ES92Int7VectorsScorer#BULK_SIZE}. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void scoreFromMemorySegmentBulk(Blackhole bh) throws IOException {
        for (int j = 0; j < numQueries; j++) {
            in.seek(0);
            // fix: stride by the int7 scorer's bulk size, not ES91Int4VectorsScorer's
            for (int i = 0; i < numVectors; i += ES92Int7VectorsScorer.BULK_SIZE) {
                scorer.scoreBulk(
                    binaryQueries[j],
                    queryCorrections.lowerInterval(),
                    queryCorrections.upperInterval(),
                    queryCorrections.quantizedComponentSum(),
                    queryCorrections.additionalCorrection(),
                    VectorSimilarityFunction.EUCLIDEAN,
                    centroidDp,
                    scores
                );
                for (float score : scores) {
                    bh.consume(score);
                }
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.simdvec;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;

import java.io.IOException;

import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;

/**
* Scorer for 7 bit quantized vectors stored in a {@link IndexInput}.
* Queries are expected to be quantized using 7 bits as well.
* Not thread-safe: instances hold mutable scratch buffers and advance the wrapped input.
* */
public class ES92Int7VectorsScorer {

/** Number of vectors read and scored per {@link #scoreBulk} call. */
public static final int BULK_SIZE = 16;
/** Dequantization step for 7-bit values: 1 / (2^7 - 1) = 1/127. */
protected static final float SEVEN_BIT_SCALE = 1f / ((1 << 7) - 1);

/** The wrapper {@link IndexInput}. */
protected final IndexInput in;
/** Number of components per vector. */
protected final int dimensions;

// Scratch buffers for the per-vector corrections read in scoreBulk;
// lowerIntervals is additionally reused as a 3-float scratch in score().
private final float[] lowerIntervals = new float[BULK_SIZE];
private final float[] upperIntervals = new float[BULK_SIZE];
private final int[] targetComponentSums = new int[BULK_SIZE];
private final float[] additionalCorrections = new float[BULK_SIZE];

/** Sole constructor, called by sub-classes. */
public ES92Int7VectorsScorer(IndexInput in, int dimensions) {
this.in = in;
this.dimensions = dimensions;
}

/**
* compute the quantize distance between the provided quantized query and the quantized vector
* that is read from the wrapped {@link IndexInput}. Reads {@code dimensions} bytes from the
* input and advances its position accordingly.
*/
public long int7DotProduct(byte[] b) throws IOException {
int total = 0;
for (int i = 0; i < dimensions; i++) {
total += in.readByte() * b[i];
}
return total;
}

/**
* compute the quantize distance between the provided quantized query and the quantized vectors
* that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
* determined by {@code count} and the results are stored in the provided {@code scores} array.
*/
public void int7DotProductBulk(byte[] b, int count, float[] scores) throws IOException {
for (int i = 0; i < count; i++) {
scores[i] = int7DotProduct(b);
}
}

/**
* Computes the score by applying the necessary corrections to the provided quantized distance.
* Expects the input to contain the quantized vector followed by its corrections: lower
* interval, upper interval, additional correction (3 floats) and component sum (1 int).
*/
public float score(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp
) throws IOException {
float score = int7DotProduct(q);
// read the three float corrections into lowerIntervals, used here as scratch:
// [0] = lower interval, [1] = upper interval, [2] = additional correction
in.readFloats(lowerIntervals, 0, 3);
int addition = in.readInt();
return applyCorrections(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
lowerIntervals[0],
lowerIntervals[1],
addition,
lowerIntervals[2],
score
);
}

/**
* compute the distance between the provided quantized query and the quantized vectors that are
* read from the wrapped {@link IndexInput}.
*
* <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
* input is as follows: First the quantized vectors are read from the input,then all the lower
* intervals as floats, then all the upper intervals as floats, then all the target component sums
* as ints, and finally all the additional corrections as floats.
*
* <p>The results are stored in the provided scores array.
*/
public void scoreBulk(
byte[] q,
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float[] scores
) throws IOException {
int7DotProductBulk(q, BULK_SIZE, scores);
in.readFloats(lowerIntervals, 0, BULK_SIZE);
in.readFloats(upperIntervals, 0, BULK_SIZE);
in.readInts(targetComponentSums, 0, BULK_SIZE);
in.readFloats(additionalCorrections, 0, BULK_SIZE);
for (int i = 0; i < BULK_SIZE; i++) {
scores[i] = applyCorrections(
queryLowerInterval,
queryUpperInterval,
queryComponentSum,
queryAdditionalCorrection,
similarityFunction,
centroidDp,
lowerIntervals[i],
upperIntervals[i],
targetComponentSums[i],
additionalCorrections[i],
scores[i]
);
}
}

/**
* Computes the score by applying the necessary corrections to the provided quantized distance.
*/
public float applyCorrections(
float queryLowerInterval,
float queryUpperInterval,
int queryComponentSum,
float queryAdditionalCorrection,
VectorSimilarityFunction similarityFunction,
float centroidDp,
float lowerInterval,
float upperInterval,
int targetComponentSum,
float additionalCorrection,
float qcDist
) {
float ax = lowerInterval;
// width of the target's quantization interval, scaled down to one 7-bit step
float lx = (upperInterval - ax) * SEVEN_BIT_SCALE;
float ay = queryLowerInterval;
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
float y1 = queryComponentSum;
// reconstruct the dot product of the dequantized vectors from the quantized one
float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
score = queryAdditionalCorrection + additionalCorrection - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score += queryAdditionalCorrection + additionalCorrection - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, i
return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
}

/** Returns an {@link ES92Int7VectorsScorer} for the given {@link IndexInput} from the
* platform-preferred vectorization provider. */
public static ES92Int7VectorsScorer getES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException {
return ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(input, dimension);
}

public static long ipByteBinByte(byte[] q, byte[] d) {
if (q.length != d.length * B_QUERY) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;

final class DefaultESVectorizationProvider extends ESVectorizationProvider {
private final ESVectorUtilSupport vectorUtilSupport;
Expand All @@ -28,12 +27,17 @@ public ESVectorUtilSupport getVectorUtilSupport() {
}

@Override
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) {
return new ES91OSQVectorsScorer(input, dimension);
}

@Override
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) {
return new ES91Int4VectorsScorer(input, dimension);
}

/** Creates the scalar (non-vectorized) {@link ES92Int7VectorsScorer} implementation. */
@Override
public ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) {
return new ES92Int7VectorsScorer(input, dimension);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;

import java.io.IOException;
import java.util.Objects;
Expand All @@ -35,6 +36,9 @@ public static ESVectorizationProvider getInstance() {
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;

/**
* Create a new {@link ES92Int7VectorsScorer} for the given {@link IndexInput}.
*
* @param input the input holding the quantized vectors and their corrections
* @param dimension the number of components per vector
* @throws IOException if an implementation needs to read from {@code input} and fails
*/
public abstract ES92Int7VectorsScorer newES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException;

// visible for tests
static ESVectorizationProvider lookup(boolean testMode) {
return new DefaultESVectorizationProvider();
Expand Down
Loading