1313import org .elasticsearch .common .logging .LogConfigurator ;
1414import org .elasticsearch .index .IndexVersion ;
1515import org .elasticsearch .script .field .vectors .BinaryDenseVector ;
16+ import org .elasticsearch .script .field .vectors .BitBinaryDenseVector ;
17+ import org .elasticsearch .script .field .vectors .BitKnnDenseVector ;
1618import org .elasticsearch .script .field .vectors .ByteBinaryDenseVector ;
1719import org .elasticsearch .script .field .vectors .ByteKnnDenseVector ;
1820import org .elasticsearch .script .field .vectors .DenseVector ;
3739import java .util .function .DoubleSupplier ;
3840
3941/**
40- * Various benchmarks for the distance functions
41- * used by indexed and non-indexed vectors.
42- * Parameters include element, dims, function, and type.
42+ * Various benchmarks for the distance functions used by indexed and non-indexed vectors.
43+ * Parameters include doc and query type, dims, function, and implementation.
4344 * For individual local tests it may be useful to increase
44- * fork, measurement, and operations per invocation. (Note
45- * to also update the benchmark loop if operations per invocation
46- * is increased.)
45+ * fork, measurement, and operations per invocation.
4746 */
4847@ Fork (1 )
4948@ Warmup (iterations = 1 )
5049@ Measurement (iterations = 2 )
5150@ BenchmarkMode (Mode .AverageTime )
5251@ OutputTimeUnit (TimeUnit .NANOSECONDS )
53- @ OperationsPerInvocation (25000 )
52+ @ OperationsPerInvocation (DistanceFunctionBenchmark . OPERATIONS )
5453@ State (Scope .Benchmark )
5554public class DistanceFunctionBenchmark {
5655
56+ public static final int OPERATIONS = 25000 ;
57+
5758 static {
5859 LogConfigurator .configureESLogging ();
5960 }
6061
6162 public enum VectorType {
6263 FLOAT ,
63- BYTE
64+ BYTE ,
65+ BIT
6466 }
6567
6668 public enum Function {
@@ -122,7 +124,7 @@ private static BytesRef generateVectorData(float[] vector) {
122124 }
123125
124126 private static BytesRef generateVectorData (float [] vector , float mag ) {
125- ByteBuffer buffer = ByteBuffer .allocate (vector .length * 4 + 4 );
127+ ByteBuffer buffer = ByteBuffer .allocate (vector .length * Float . BYTES + Float . BYTES );
126128 for (float f : vector ) {
127129 buffer .putFloat (f );
128130 }
@@ -133,24 +135,29 @@ private static BytesRef generateVectorData(float[] vector, float mag) {
133135 private static BytesRef generateVectorData (byte [] vector ) {
134136 float mag = calculateMag (vector );
135137
136- ByteBuffer buffer = ByteBuffer .allocate (vector .length + 4 );
138+ ByteBuffer buffer = ByteBuffer .allocate (vector .length + Float . BYTES );
137139 buffer .put (vector );
138140 buffer .putFloat (mag );
139141 return new BytesRef (buffer .array ());
140142 }
141143
142144 @ Setup
143145 public void findBenchmarkImpl () {
146+ if (dims % 8 != 0 ) throw new IllegalArgumentException ("Dims must be a multiple of 8" );
144147 Random r = new Random ();
145148
146149 float [] floatDocVector = new float [dims ];
147150 byte [] byteDocVector = new byte [dims ];
151+ byte [] bitDocVector = new byte [dims / 8 ];
148152
149153 float [] floatQueryVector = new float [dims ];
150154 byte [] byteQueryVector = new byte [dims ];
155+ byte [] bitQueryVector = new byte [dims / 8 ];
151156
152157 r .nextBytes (byteDocVector );
158+ r .nextBytes (bitDocVector );
153159 r .nextBytes (byteQueryVector );
160+ r .nextBytes (bitQueryVector );
154161 for (int i = 0 ; i < dims ; i ++) {
155162 floatDocVector [i ] = r .nextFloat ();
156163 floatQueryVector [i ] = r .nextFloat ();
@@ -179,10 +186,11 @@ public void findBenchmarkImpl() {
179186 };
180187 case BYTE -> switch (type ) {
181188 case KNN -> new ByteKnnDenseVector (byteDocVector );
182- case BINARY -> {
183- BytesRef vectorData = generateVectorData (byteDocVector );
184- yield new ByteBinaryDenseVector (byteDocVector , vectorData , dims );
185- }
189+ case BINARY -> new ByteBinaryDenseVector (byteDocVector , generateVectorData (byteDocVector ), dims );
190+ };
191+ case BIT -> switch (type ) {
192+ case KNN -> new BitKnnDenseVector (bitDocVector );
193+ case BINARY -> new BitBinaryDenseVector (bitDocVector , new BytesRef (bitDocVector ), bitDocVector .length );
186194 };
187195 };
188196
@@ -204,21 +212,28 @@ public void findBenchmarkImpl() {
204212 case L2 -> () -> vectorImpl .l2Norm (byteQueryVector );
205213 case HAMMING -> () -> vectorImpl .hamming (byteQueryVector );
206214 };
215+ case BIT -> switch (function ) {
216+ case DOT -> () -> vectorImpl .dotProduct (bitQueryVector );
217+ case COSINE -> throw new UnsupportedOperationException ("Unsupported function " + function );
218+ case L1 -> () -> vectorImpl .l1Norm (bitQueryVector );
219+ case L2 -> () -> vectorImpl .l2Norm (bitQueryVector );
220+ case HAMMING -> () -> vectorImpl .hamming (bitQueryVector );
221+ };
207222 };
208223 }
209224
210225 @ Fork (1 )
211226 @ Benchmark
212227 public void benchmark (Blackhole blackhole ) {
213- for (int i = 0 ; i < 25000 ; ++i ) {
228+ for (int i = 0 ; i < OPERATIONS ; ++i ) {
214229 blackhole .consume (benchmarkImpl .getAsDouble ());
215230 }
216231 }
217232
218233 @ Fork (value = 1 , jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
219234 @ Benchmark
220235 public void vectorBenchmark (Blackhole blackhole ) {
221- for (int i = 0 ; i < 25000 ; ++i ) {
236+ for (int i = 0 ; i < OPERATIONS ; ++i ) {
222237 blackhole .consume (benchmarkImpl .getAsDouble ());
223238 }
224239 }
0 commit comments