1313import org .elasticsearch .common .logging .LogConfigurator ;
1414import org .elasticsearch .index .IndexVersion ;
1515import org .elasticsearch .script .field .vectors .BinaryDenseVector ;
16+ import org .elasticsearch .script .field .vectors .BitBinaryDenseVector ;
17+ import org .elasticsearch .script .field .vectors .BitKnnDenseVector ;
1618import org .elasticsearch .script .field .vectors .ByteBinaryDenseVector ;
1719import org .elasticsearch .script .field .vectors .ByteKnnDenseVector ;
1820import org .elasticsearch .script .field .vectors .DenseVector ;
@@ -60,7 +62,8 @@ public class DistanceFunctionBenchmark {
6062
6163 public enum VectorType {
6264 FLOAT ,
63- BYTE
65+ BYTE ,
66+ BIT
6467 }
6568
6669 public enum Function {
@@ -145,12 +148,16 @@ public void findBenchmarkImpl() {
145148
146149 float [] floatDocVector = new float [dims ];
147150 byte [] byteDocVector = new byte [dims ];
151+ byte [] bitDocVector = new byte [dims /8 ];
148152
149153 float [] floatQueryVector = new float [dims ];
150154 byte [] byteQueryVector = new byte [dims ];
155+ byte [] bitQueryVector = new byte [dims /8 ];
151156
152157 r .nextBytes (byteDocVector );
158+ r .nextBytes (bitDocVector );
153159 r .nextBytes (byteQueryVector );
160+ r .nextBytes (bitQueryVector );
154161 for (int i = 0 ; i < dims ; i ++) {
155162 floatDocVector [i ] = r .nextFloat ();
156163 floatQueryVector [i ] = r .nextFloat ();
@@ -179,10 +186,11 @@ public void findBenchmarkImpl() {
179186 };
180187 case BYTE -> switch (type ) {
181188 case KNN -> new ByteKnnDenseVector (byteDocVector );
182- case BINARY -> {
183- BytesRef vectorData = generateVectorData (byteDocVector );
184- yield new ByteBinaryDenseVector (byteDocVector , vectorData , dims );
185- }
189+ case BINARY -> new ByteBinaryDenseVector (byteDocVector , generateVectorData (byteDocVector ), dims );
190+ };
191+ case BIT -> switch (type ) {
192+ case KNN -> new BitKnnDenseVector (bitDocVector );
193+ case BINARY -> new BitBinaryDenseVector (bitDocVector , new BytesRef (bitDocVector ), bitDocVector .length );
186194 };
187195 };
188196
@@ -204,6 +212,13 @@ public void findBenchmarkImpl() {
204212 case L2 -> () -> vectorImpl .l2Norm (byteQueryVector );
205213 case HAMMING -> () -> vectorImpl .hamming (byteQueryVector );
206214 };
215+ case BIT -> switch (function ) {
216+ case DOT -> () -> vectorImpl .dotProduct (bitQueryVector );
217+ case COSINE -> throw new UnsupportedOperationException ("Unsupported function " + function );
218+ case L1 -> () -> vectorImpl .l1Norm (bitQueryVector );
219+ case L2 -> () -> vectorImpl .l2Norm (bitQueryVector );
220+ case HAMMING -> () -> vectorImpl .hamming (bitQueryVector );
221+ };
207222 };
208223 }
209224
0 commit comments