Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
import org.elasticsearch.script.field.vectors.BitBinaryDenseVector;
import org.elasticsearch.script.field.vectors.BitKnnDenseVector;
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
import org.elasticsearch.script.field.vectors.DenseVector;
Expand All @@ -37,30 +39,30 @@
import java.util.function.DoubleSupplier;

/**
* Various benchmarks for the distance functions
* used by indexed and non-indexed vectors.
* Parameters include element, dims, function, and type.
* Various benchmarks for the distance functions used by indexed and non-indexed vectors.
* Parameters include doc and query type, dims, function, and implementation.
* For individual local tests it may be useful to increase
* fork, measurement, and operations per invocation. (Note
* to also update the benchmark loop if operations per invocation
* is increased.)
* fork, measurement, and operations per invocation.
*/
@Fork(1)
@Warmup(iterations = 1)
@Measurement(iterations = 2)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@OperationsPerInvocation(25000)
@OperationsPerInvocation(DistanceFunctionBenchmark.OPERATIONS)
@State(Scope.Benchmark)
public class DistanceFunctionBenchmark {

public static final int OPERATIONS = 25000;

static {
LogConfigurator.configureESLogging();
}

public enum VectorType {
FLOAT,
BYTE
BYTE,
BIT
}

public enum Function {
Expand Down Expand Up @@ -122,7 +124,7 @@ private static BytesRef generateVectorData(float[] vector) {
}

private static BytesRef generateVectorData(float[] vector, float mag) {
ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4);
ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES + Float.BYTES);
for (float f : vector) {
buffer.putFloat(f);
}
Expand All @@ -133,24 +135,29 @@ private static BytesRef generateVectorData(float[] vector, float mag) {
private static BytesRef generateVectorData(byte[] vector) {
float mag = calculateMag(vector);

ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4);
ByteBuffer buffer = ByteBuffer.allocate(vector.length + Float.BYTES);
buffer.put(vector);
buffer.putFloat(mag);
return new BytesRef(buffer.array());
}

@Setup
public void findBenchmarkImpl() {
if (dims % 8 != 0) throw new IllegalArgumentException("Dims must be a multiple of 8");
Random r = new Random();

float[] floatDocVector = new float[dims];
byte[] byteDocVector = new byte[dims];
byte[] bitDocVector = new byte[dims / 8];

float[] floatQueryVector = new float[dims];
byte[] byteQueryVector = new byte[dims];
byte[] bitQueryVector = new byte[dims / 8];

r.nextBytes(byteDocVector);
r.nextBytes(bitDocVector);
r.nextBytes(byteQueryVector);
r.nextBytes(bitQueryVector);
for (int i = 0; i < dims; i++) {
floatDocVector[i] = r.nextFloat();
floatQueryVector[i] = r.nextFloat();
Expand Down Expand Up @@ -179,10 +186,11 @@ public void findBenchmarkImpl() {
};
case BYTE -> switch (type) {
case KNN -> new ByteKnnDenseVector(byteDocVector);
case BINARY -> {
BytesRef vectorData = generateVectorData(byteDocVector);
yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims);
}
case BINARY -> new ByteBinaryDenseVector(byteDocVector, generateVectorData(byteDocVector), dims);
};
case BIT -> switch (type) {
case KNN -> new BitKnnDenseVector(bitDocVector);
case BINARY -> new BitBinaryDenseVector(bitDocVector, new BytesRef(bitDocVector), bitDocVector.length);
};
};

Expand All @@ -204,21 +212,28 @@ public void findBenchmarkImpl() {
case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
};
case BIT -> switch (function) {
case DOT -> () -> vectorImpl.dotProduct(bitQueryVector);
case COSINE -> throw new UnsupportedOperationException("Unsupported function " + function);
case L1 -> () -> vectorImpl.l1Norm(bitQueryVector);
case L2 -> () -> vectorImpl.l2Norm(bitQueryVector);
case HAMMING -> () -> vectorImpl.hamming(bitQueryVector);
};
};
}

@Fork(1)
@Benchmark
public void benchmark(Blackhole blackhole) {
for (int i = 0; i < 25000; ++i) {
for (int i = 0; i < OPERATIONS; ++i) {
blackhole.consume(benchmarkImpl.getAsDouble());
}
}

@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
@Benchmark
public void vectorBenchmark(Blackhole blackhole) {
for (int i = 0; i < 25000; ++i) {
for (int i = 0; i < OPERATIONS; ++i) {
blackhole.consume(benchmarkImpl.getAsDouble());
}
}
Expand Down
6 changes: 6 additions & 0 deletions docs/changelog/124722.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 124722
summary: Add panama implementations of byte-bit and float-bit script operations
area: Vector Search
type: enhancement
issues:
- 117096
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,17 @@ public float ipFloatByte(float[] q, byte[] d) {
}

public static int ipByteBitImpl(byte[] q, byte[] d) {
return ipByteBitImpl(q, d, 0);
}

public static int ipByteBitImpl(byte[] q, byte[] d, int start) {
assert q.length == d.length * Byte.SIZE;
int acc0 = 0;
int acc1 = 0;
int acc2 = 0;
int acc3 = 0;
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
for (int i = start; i < d.length; i++) {
byte mask = d[i];
// Make sure its just 1 or 0

Expand All @@ -69,13 +73,17 @@ public static int ipByteBitImpl(byte[] q, byte[] d) {
}

public static float ipFloatBitImpl(float[] q, byte[] d) {
return ipFloatBitImpl(q, d, 0);
}

static float ipFloatBitImpl(float[] q, byte[] d, int start) {
assert q.length == d.length * Byte.SIZE;
float acc0 = 0;
float acc1 = 0;
float acc2 = 0;
float acc3 = 0;
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
for (int i = start; i < d.length; i++) {
byte mask = d[i];
acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);
Expand Down
Loading
Loading