1616import org .apache .lucene .store .IndexOutput ;
1717import org .apache .lucene .store .MMapDirectory ;
1818import org .apache .lucene .store .NIOFSDirectory ;
19- import org .apache .lucene .util .quantization . OptimizedScalarQuantizer ;
19+ import org .apache .lucene .util .VectorUtil ;
2020import org .elasticsearch .common .logging .LogConfigurator ;
2121import org .elasticsearch .core .IOUtils ;
2222import org .elasticsearch .index .codec .vectors .diskbbq .next .ESNextDiskBBQVectorsFormat ;
2323import org .elasticsearch .simdvec .ESNextOSQVectorsScorer ;
2424import org .elasticsearch .simdvec .internal .vectorization .ESVectorizationProvider ;
25+ import org .elasticsearch .simdvec .internal .vectorization .VectorScorerTestUtils ;
2526import org .elasticsearch .xpack .searchablesnapshots .store .SearchableSnapshotDirectoryFactory ;
2627import org .openjdk .jmh .annotations .Benchmark ;
2728import org .openjdk .jmh .annotations .BenchmarkMode ;
4243import java .util .Random ;
4344import java .util .concurrent .TimeUnit ;
4445
46+ import static org .elasticsearch .simdvec .internal .vectorization .VectorScorerTestUtils .createOSQIndexData ;
47+ import static org .elasticsearch .simdvec .internal .vectorization .VectorScorerTestUtils .createOSQQueryData ;
48+ import static org .elasticsearch .simdvec .internal .vectorization .VectorScorerTestUtils .randomVector ;
49+ import static org .elasticsearch .simdvec .internal .vectorization .VectorScorerTestUtils .writeBulkOSQVectorData ;
50+
4551@ BenchmarkMode (Mode .Throughput )
4652@ OutputTimeUnit (TimeUnit .MILLISECONDS )
4753@ State (Scope .Benchmark )
@@ -72,7 +78,7 @@ public enum VectorImplementation {
7278 public int dims ;
7379
7480 @ Param ({ "1" , "2" , "4" })
75- public int bits ;
81+ public byte bits ;
7682
7783 int bulkSize = ESNextOSQVectorsScorer .BULK_SIZE ;
7884
@@ -90,9 +96,7 @@ public enum VectorImplementation {
9096
9197 int length ;
9298
93- byte [][] binaryVectors ;
94- byte [][] binaryQueries ;
95- OptimizedScalarQuantizer .QuantizationResult result ;
99+ VectorScorerTestUtils .OSQVectorData [] binaryQueries ;
96100 float centroidDp ;
97101
98102 byte [] scratch ;
@@ -111,17 +115,12 @@ public void setup() throws IOException {
111115 }
112116
113117 void setup (Random random ) throws IOException {
114- this .length = switch (bits ) {
115- case 1 -> ESNextDiskBBQVectorsFormat .QuantEncoding .ONE_BIT_4BIT_QUERY .getDocPackedLength (dims );
116- case 2 -> ESNextDiskBBQVectorsFormat .QuantEncoding .TWO_BIT_4BIT_QUERY .getDocPackedLength (dims );
117- case 4 -> ESNextDiskBBQVectorsFormat .QuantEncoding .FOUR_BIT_SYMMETRIC .getDocPackedLength (dims );
118- default -> throw new IllegalArgumentException ("Unsupported bits: " + bits );
119- };
118+ this .length = ESNextDiskBBQVectorsFormat .QuantEncoding .fromBits (bits ).getDocPackedLength (dims );
120119
121- binaryVectors = new byte [ numVectors ][ length ];
122- for ( byte [] binaryVector : binaryVectors ) {
123- random . nextBytes ( binaryVector );
124- }
120+ final float [] centroid = new float [ dims ];
121+ randomVector ( random , centroid , similarityFunction );
122+
123+ var quantizer = new org . elasticsearch . index . codec . vectors . OptimizedScalarQuantizer ( similarityFunction );
125124
126125 directory = switch (directoryType ) {
127126 case MMAP -> new MMapDirectory (createTempDirectory ("vectorDataMmap" ));
@@ -130,35 +129,27 @@ void setup(Random random) throws IOException {
130129 };
131130
132131 try (IndexOutput output = directory .createOutput ("vectors" , IOContext .DEFAULT )) {
133- byte [] correctionBytes = new byte [ 16 * bulkSize ];
132+ VectorScorerTestUtils . OSQVectorData [] vectors = new VectorScorerTestUtils . OSQVectorData [ bulkSize ];
134133 for (int i = 0 ; i < numVectors ; i += bulkSize ) {
135134 for (int j = 0 ; j < bulkSize ; j ++) {
136- output .writeBytes (binaryVectors [i + j ], 0 , binaryVectors [i + j ].length );
135+ var vector = new float [dims ];
136+ randomVector (random , vector , similarityFunction );
137+ vectors [j ] = createOSQIndexData (vector , centroid , quantizer , dims , bits , length );
137138 }
138- random .nextBytes (correctionBytes );
139- output .writeBytes (correctionBytes , 0 , correctionBytes .length );
139+ writeBulkOSQVectorData (bulkSize , output , vectors );
140140 }
141141 CodecUtil .writeFooter (output );
142142 }
143143 input = directory .openInput ("vectors" , IOContext .DEFAULT );
144- int binaryQueryLength = switch (bits ) {
145- case 1 -> ESNextDiskBBQVectorsFormat .QuantEncoding .ONE_BIT_4BIT_QUERY .getQueryPackedLength (dims );
146- case 2 -> ESNextDiskBBQVectorsFormat .QuantEncoding .TWO_BIT_4BIT_QUERY .getQueryPackedLength (dims );
147- case 4 -> ESNextDiskBBQVectorsFormat .QuantEncoding .FOUR_BIT_SYMMETRIC .getQueryPackedLength (dims );
148- default -> throw new IllegalArgumentException ("Unsupported bits: " + bits );
149- };
144+ int binaryQueryLength = ESNextDiskBBQVectorsFormat .QuantEncoding .fromBits (bits ).getQueryPackedLength (dims );
150145
151- binaryQueries = new byte [numVectors ][binaryQueryLength ];
152- for (byte [] binaryQuery : binaryQueries ) {
153- random .nextBytes (binaryQuery );
146+ binaryQueries = new VectorScorerTestUtils .OSQVectorData [numVectors ];
147+ var query = new float [dims ];
148+ for (int i = 0 ; i < numVectors ; ++i ) {
149+ randomVector (random , query , similarityFunction );
150+ binaryQueries [i ] = createOSQQueryData (query , centroid , quantizer , dims , (byte ) 4 , binaryQueryLength );
154151 }
155- result = new OptimizedScalarQuantizer .QuantizationResult (
156- random .nextFloat (),
157- random .nextFloat (),
158- random .nextFloat (),
159- Short .toUnsignedInt ((short ) random .nextInt ())
160- );
161- centroidDp = random .nextFloat ();
152+ centroidDp = VectorUtil .dotProduct (centroid , centroid );
162153
163154 scratch = new byte [length ];
164155 final int docBits ;
@@ -202,14 +193,14 @@ public float[] score() throws IOException {
202193 for (int j = 0 ; j < numQueries ; j ++) {
203194 input .seek (0 );
204195 for (int i = 0 ; i < numVectors ; i ++) {
205- float qDist = scorer .quantizeScore (binaryQueries [j ]);
196+ float qDist = scorer .quantizeScore (binaryQueries [j ]. quantizedVector () );
206197 input .readFloats (corrections , 0 , corrections .length );
207198 int addition = Short .toUnsignedInt (input .readShort ());
208199 float score = scorer .score (
209- result .lowerInterval (),
210- result .upperInterval (),
211- result .quantizedComponentSum (),
212- result .additionalCorrection (),
200+ binaryQueries [ j ] .lowerInterval (),
201+ binaryQueries [ j ] .upperInterval (),
202+ binaryQueries [ j ] .quantizedComponentSum (),
203+ binaryQueries [ j ] .additionalCorrection (),
213204 similarityFunction ,
214205 centroidDp ,
215206 corrections [0 ],
@@ -231,11 +222,11 @@ public float[] bulkScore() throws IOException {
231222 input .seek (0 );
232223 for (int i = 0 ; i < numVectors ; i += scratchScores .length ) {
233224 scorer .scoreBulk (
234- binaryQueries [j ],
235- result .lowerInterval (),
236- result .upperInterval (),
237- result .quantizedComponentSum (),
238- result .additionalCorrection (),
225+ binaryQueries [j ]. quantizedVector () ,
226+ binaryQueries [ j ] .lowerInterval (),
227+ binaryQueries [ j ] .upperInterval (),
228+ binaryQueries [ j ] .quantizedComponentSum (),
229+ binaryQueries [ j ] .additionalCorrection (),
239230 similarityFunction ,
240231 centroidDp ,
241232 scratchScores
0 commit comments