Skip to content

Commit 1493794

Browse files
committed
Add bit vectors to DistanceFunctionBenchmark
1 parent d7864f4 commit 1493794

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.elasticsearch.common.logging.LogConfigurator;
1414
import org.elasticsearch.index.IndexVersion;
1515
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
16+
import org.elasticsearch.script.field.vectors.BitBinaryDenseVector;
17+
import org.elasticsearch.script.field.vectors.BitKnnDenseVector;
1618
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
1719
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
1820
import org.elasticsearch.script.field.vectors.DenseVector;
@@ -60,7 +62,8 @@ public class DistanceFunctionBenchmark {
6062

6163
public enum VectorType {
6264
FLOAT,
63-
BYTE
65+
BYTE,
66+
BIT
6467
}
6568

6669
public enum Function {
@@ -145,12 +148,16 @@ public void findBenchmarkImpl() {
145148

146149
float[] floatDocVector = new float[dims];
147150
byte[] byteDocVector = new byte[dims];
151+
byte[] bitDocVector = new byte[dims/8];
148152

149153
float[] floatQueryVector = new float[dims];
150154
byte[] byteQueryVector = new byte[dims];
155+
byte[] bitQueryVector = new byte[dims/8];
151156

152157
r.nextBytes(byteDocVector);
158+
r.nextBytes(bitDocVector);
153159
r.nextBytes(byteQueryVector);
160+
r.nextBytes(bitQueryVector);
154161
for (int i = 0; i < dims; i++) {
155162
floatDocVector[i] = r.nextFloat();
156163
floatQueryVector[i] = r.nextFloat();
@@ -179,10 +186,11 @@ public void findBenchmarkImpl() {
179186
};
180187
case BYTE -> switch (type) {
181188
case KNN -> new ByteKnnDenseVector(byteDocVector);
182-
case BINARY -> {
183-
BytesRef vectorData = generateVectorData(byteDocVector);
184-
yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims);
185-
}
189+
case BINARY -> new ByteBinaryDenseVector(byteDocVector, generateVectorData(byteDocVector), dims);
190+
};
191+
case BIT -> switch (type) {
192+
case KNN -> new BitKnnDenseVector(bitDocVector);
193+
case BINARY -> new BitBinaryDenseVector(bitDocVector, new BytesRef(bitDocVector), bitDocVector.length);
186194
};
187195
};
188196

@@ -204,6 +212,13 @@ public void findBenchmarkImpl() {
204212
case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
205213
case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
206214
};
215+
case BIT -> switch (function) {
216+
case DOT -> () -> vectorImpl.dotProduct(bitQueryVector);
217+
case COSINE -> throw new UnsupportedOperationException("Unsupported function " + function);
218+
case L1 -> () -> vectorImpl.l1Norm(bitQueryVector);
219+
case L2 -> () -> vectorImpl.l2Norm(bitQueryVector);
220+
case HAMMING -> () -> vectorImpl.hamming(bitQueryVector);
221+
};
207222
};
208223
}
209224

0 commit comments

Comments
 (0)