1313import  org .elasticsearch .common .logging .LogConfigurator ;
1414import  org .elasticsearch .index .IndexVersion ;
1515import  org .elasticsearch .script .field .vectors .BinaryDenseVector ;
16+ import  org .elasticsearch .script .field .vectors .BitBinaryDenseVector ;
17+ import  org .elasticsearch .script .field .vectors .BitKnnDenseVector ;
1618import  org .elasticsearch .script .field .vectors .ByteBinaryDenseVector ;
1719import  org .elasticsearch .script .field .vectors .ByteKnnDenseVector ;
1820import  org .elasticsearch .script .field .vectors .DenseVector ;
3739import  java .util .function .DoubleSupplier ;
3840
3941/** 
40-  * Various benchmarks for the distance functions 
41-  * used by indexed and non-indexed vectors. 
42-  * Parameters include element, dims, function, and type. 
42+  * Various benchmarks for the distance functions used by indexed and non-indexed vectors. 
43+  * Parameters include doc and query type, dims, function, and implementation. 
4344 * For individual local tests it may be useful to increase 
44-  * fork, measurement, and operations per invocation. (Note 
45-  * to also update the benchmark loop if operations per invocation 
46-  * is increased.) 
45+  * fork, measurement, and operations per invocation. 
4746 */ 
4847@ Fork (1 )
4948@ Warmup (iterations  = 1 )
5049@ Measurement (iterations  = 2 )
5150@ BenchmarkMode (Mode .AverageTime )
5251@ OutputTimeUnit (TimeUnit .NANOSECONDS )
53- @ OperationsPerInvocation (25000 )
52+ @ OperationsPerInvocation (DistanceFunctionBenchmark . OPERATIONS )
5453@ State (Scope .Benchmark )
5554public  class  DistanceFunctionBenchmark  {
5655
56+     public  static  final  int  OPERATIONS  = 25000 ;
57+ 
5758    static  {
5859        LogConfigurator .configureESLogging ();
5960    }
6061
6162    public  enum  VectorType  {
6263        FLOAT ,
63-         BYTE 
64+         BYTE ,
65+         BIT 
6466    }
6567
6668    public  enum  Function  {
@@ -122,7 +124,7 @@ private static BytesRef generateVectorData(float[] vector) {
122124    }
123125
124126    private  static  BytesRef  generateVectorData (float [] vector , float  mag ) {
125-         ByteBuffer  buffer  = ByteBuffer .allocate (vector .length  * 4  + 4 );
127+         ByteBuffer  buffer  = ByteBuffer .allocate (vector .length  * Float . BYTES  + Float . BYTES );
126128        for  (float  f  : vector ) {
127129            buffer .putFloat (f );
128130        }
@@ -133,24 +135,29 @@ private static BytesRef generateVectorData(float[] vector, float mag) {
133135    private  static  BytesRef  generateVectorData (byte [] vector ) {
134136        float  mag  = calculateMag (vector );
135137
136-         ByteBuffer  buffer  = ByteBuffer .allocate (vector .length  + 4 );
138+         ByteBuffer  buffer  = ByteBuffer .allocate (vector .length  + Float . BYTES );
137139        buffer .put (vector );
138140        buffer .putFloat (mag );
139141        return  new  BytesRef (buffer .array ());
140142    }
141143
142144    @ Setup 
143145    public  void  findBenchmarkImpl () {
146+         if  (dims  % 8  != 0 ) throw  new  IllegalArgumentException ("Dims must be a multiple of 8" );
144147        Random  r  = new  Random ();
145148
146149        float [] floatDocVector  = new  float [dims ];
147150        byte [] byteDocVector  = new  byte [dims ];
151+         byte [] bitDocVector  = new  byte [dims  / 8 ];
148152
149153        float [] floatQueryVector  = new  float [dims ];
150154        byte [] byteQueryVector  = new  byte [dims ];
155+         byte [] bitQueryVector  = new  byte [dims  / 8 ];
151156
152157        r .nextBytes (byteDocVector );
158+         r .nextBytes (bitDocVector );
153159        r .nextBytes (byteQueryVector );
160+         r .nextBytes (bitQueryVector );
154161        for  (int  i  = 0 ; i  < dims ; i ++) {
155162            floatDocVector [i ] = r .nextFloat ();
156163            floatQueryVector [i ] = r .nextFloat ();
@@ -179,10 +186,11 @@ public void findBenchmarkImpl() {
179186            };
180187            case  BYTE  -> switch  (type ) {
181188                case  KNN  -> new  ByteKnnDenseVector (byteDocVector );
182-                 case  BINARY  -> {
183-                     BytesRef  vectorData  = generateVectorData (byteDocVector );
184-                     yield new  ByteBinaryDenseVector (byteDocVector , vectorData , dims );
185-                 }
189+                 case  BINARY  -> new  ByteBinaryDenseVector (byteDocVector , generateVectorData (byteDocVector ), dims );
190+             };
191+             case  BIT  -> switch  (type ) {
192+                 case  KNN  -> new  BitKnnDenseVector (bitDocVector );
193+                 case  BINARY  -> new  BitBinaryDenseVector (bitDocVector , new  BytesRef (bitDocVector ), bitDocVector .length );
186194            };
187195        };
188196
@@ -204,21 +212,28 @@ public void findBenchmarkImpl() {
204212                case  L2  -> () -> vectorImpl .l2Norm (byteQueryVector );
205213                case  HAMMING  -> () -> vectorImpl .hamming (byteQueryVector );
206214            };
215+             case  BIT  -> switch  (function ) {
216+                 case  DOT  -> () -> vectorImpl .dotProduct (bitQueryVector );
217+                 case  COSINE  -> throw  new  UnsupportedOperationException ("Unsupported function "  + function );
218+                 case  L1  -> () -> vectorImpl .l1Norm (bitQueryVector );
219+                 case  L2  -> () -> vectorImpl .l2Norm (bitQueryVector );
220+                 case  HAMMING  -> () -> vectorImpl .hamming (bitQueryVector );
221+             };
207222        };
208223    }
209224
210225    @ Fork (1 )
211226    @ Benchmark 
212227    public  void  benchmark (Blackhole  blackhole ) {
213-         for  (int  i  = 0 ; i  < 25000 ; ++i ) {
228+         for  (int  i  = 0 ; i  < OPERATIONS ; ++i ) {
214229            blackhole .consume (benchmarkImpl .getAsDouble ());
215230        }
216231    }
217232
218233    @ Fork (value  = 1 , jvmArgsPrepend  = { "--add-modules=jdk.incubator.vector"  })
219234    @ Benchmark 
220235    public  void  vectorBenchmark (Blackhole  blackhole ) {
221-         for  (int  i  = 0 ; i  < 25000 ; ++i ) {
236+         for  (int  i  = 0 ; i  < OPERATIONS ; ++i ) {
222237            blackhole .consume (benchmarkImpl .getAsDouble ());
223238        }
224239    }
0 commit comments