2020import org .elasticsearch .simdvec .ES91Int4VectorsScorer ;
2121import org .elasticsearch .simdvec .ES91OSQVectorsScorer ;
2222
23- import static org .hamcrest .Matchers .lessThan ;
23+ import java .io .IOException ;
24+
25+ import static org .hamcrest .Matchers .greaterThan ;
2426
2527public class ES91Int4VectorScorerTests extends BaseVectorizationTests {
2628
@@ -130,31 +132,59 @@ public void testInt4ScoreBulk() throws Exception {
130132 // only even dimensions are supported
131133 final int dimensions = random ().nextInt (1 , 1000 ) * 2 ;
132134 final int numVectors = random ().nextInt (1 , 10 ) * ES91Int4VectorsScorer .BULK_SIZE ;
133- final byte [] vector = new byte [ES91Int4VectorsScorer .BULK_SIZE * dimensions ];
134- final byte [] corrections = new byte [ES91Int4VectorsScorer .BULK_SIZE * 14 ];
135+ final float [][] vectors = new float [numVectors ][dimensions ];
136+ final int [] quantizedScratch = new int [dimensions ];
137+ final byte [] quantizeVector = new byte [dimensions ];
138+ final float [] centroid = new float [dimensions ];
139+ VectorSimilarityFunction similarityFunction = randomFrom (VectorSimilarityFunction .values ());
140+ for (int i = 0 ; i < dimensions ; i ++) {
141+ centroid [i ] = random ().nextFloat ();
142+ }
143+ if (similarityFunction != VectorSimilarityFunction .EUCLIDEAN ) {
144+ VectorUtil .l2normalize (centroid );
145+ }
146+
147+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer (similarityFunction );
135148 try (Directory dir = new MMapDirectory (createTempDir ())) {
136149 try (IndexOutput out = dir .createOutput ("tests.bin" , IOContext .DEFAULT )) {
150+ OptimizedScalarQuantizer .QuantizationResult [] results =
151+ new OptimizedScalarQuantizer .QuantizationResult [ES91Int4VectorsScorer .BULK_SIZE ];
137152 for (int i = 0 ; i < numVectors ; i += ES91Int4VectorsScorer .BULK_SIZE ) {
138- for (int j = 0 ; j < ES91Int4VectorsScorer .BULK_SIZE * dimensions ; j ++) {
139- vector [j ] = (byte ) random ().nextInt (16 ); // 4-bit quantization
153+ for (int j = 0 ; j < ES91Int4VectorsScorer .BULK_SIZE ; j ++) {
154+ for (int k = 0 ; k < dimensions ; k ++) {
155+ vectors [i + j ][k ] = random ().nextFloat ();
156+ }
157+ if (similarityFunction != VectorSimilarityFunction .EUCLIDEAN ) {
158+ VectorUtil .l2normalize (vectors [i + j ]);
159+ }
160+ results [j ] = quantizer .scalarQuantize (vectors [i + j ].clone (), quantizedScratch , (byte ) 4 , centroid );
161+ for (int k = 0 ; k < dimensions ; k ++) {
162+ quantizeVector [k ] = (byte ) quantizedScratch [k ];
163+ }
164+ out .writeBytes (quantizeVector , 0 , dimensions );
140165 }
141- out .writeBytes (vector , 0 , vector .length );
142- random ().nextBytes (corrections );
143- out .writeBytes (corrections , 0 , corrections .length );
166+ writeCorrections (results , out );
144167 }
145168 }
146- final byte [] query = new byte [dimensions ];
169+ final float [] query = new float [dimensions ];
170+ final byte [] quantizeQuery = new byte [dimensions ];
147171 for (int j = 0 ; j < dimensions ; j ++) {
148- query [j ] = ( byte ) random ().nextInt ( 16 ); // 4-bit quantization
172+ query [j ] = random ().nextFloat ();
149173 }
150- OptimizedScalarQuantizer .QuantizationResult queryCorrections = new OptimizedScalarQuantizer .QuantizationResult (
151- random ().nextFloat (),
152- random ().nextFloat (),
153- random ().nextFloat (),
154- Short .toUnsignedInt ((short ) random ().nextInt ())
174+ if (similarityFunction != VectorSimilarityFunction .EUCLIDEAN ) {
175+ VectorUtil .l2normalize (query );
176+ }
177+ OptimizedScalarQuantizer .QuantizationResult queryCorrections = quantizer .scalarQuantize (
178+ query .clone (),
179+ quantizedScratch ,
180+ (byte ) 4 ,
181+ centroid
155182 );
156- float centroidDp = random ().nextFloat ();
157- VectorSimilarityFunction similarityFunction = randomFrom (VectorSimilarityFunction .values ());
183+ for (int j = 0 ; j < dimensions ; j ++) {
184+ quantizeQuery [j ] = (byte ) quantizedScratch [j ];
185+ }
186+ float centroidDp = VectorUtil .dotProduct (centroid , centroid );
187+
158188 try (IndexInput in = dir .openInput ("tests.bin" , IOContext .DEFAULT )) {
159189 // Work on a slice that has just the right number of bytes to make the test fail with an
160190 // index-out-of-bounds in case the implementation reads more than the allowed number of
@@ -166,7 +196,7 @@ public void testInt4ScoreBulk() throws Exception {
166196 float [] scoresPanama = new float [ES91Int4VectorsScorer .BULK_SIZE ];
167197 for (int i = 0 ; i < numVectors ; i += ES91Int4VectorsScorer .BULK_SIZE ) {
168198 defaultScorer .scoreBulk (
169- query ,
199+ quantizeQuery ,
170200 queryCorrections .lowerInterval (),
171201 queryCorrections .upperInterval (),
172202 queryCorrections .quantizedComponentSum (),
@@ -176,7 +206,7 @@ public void testInt4ScoreBulk() throws Exception {
176206 scoresDefault
177207 );
178208 panamaScorer .scoreBulk (
179- query ,
209+ quantizeQuery ,
180210 queryCorrections .lowerInterval (),
181211 queryCorrections .upperInterval (),
182212 queryCorrections .quantizedComponentSum (),
@@ -186,29 +216,34 @@ public void testInt4ScoreBulk() throws Exception {
186216 scoresPanama
187217 );
188218 for (int j = 0 ; j < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
189- if (scoresDefault [j ] == scoresPanama [j ]) {
190- continue ;
191- }
192- if (scoresDefault [j ] > (1000 * Byte .MAX_VALUE )) {
193- float diff = Math .abs (scoresDefault [j ] - scoresPanama [j ]);
194- assertThat (
195- "defaultScores: " + scoresDefault [j ] + " bulkScores: " + scoresPanama [j ],
196- diff / scoresDefault [j ],
197- lessThan (1e-5f )
198- );
199- assertThat (
200- "defaultScores: " + scoresDefault [j ] + " bulkScores: " + scoresPanama [j ],
201- diff / scoresPanama [j ],
202- lessThan (1e-5f )
203- );
204- } else {
205- assertEquals (scoresDefault [j ], scoresPanama [j ], 1e-2f );
206- }
219+ assertEquals (scoresDefault [j ], scoresPanama [j ], 1e-2f );
220+ float realSimilarity = similarityFunction .compare (vectors [i + j ], query );
221+ float accuracy = realSimilarity > scoresDefault [j ]
222+ ? scoresDefault [j ] / realSimilarity
223+ : realSimilarity / scoresDefault [j ];
224+ assertThat (accuracy , greaterThan (0.90f ));
207225 }
208226 assertEquals (in .getFilePointer (), slice .getFilePointer ());
209227 }
210228 assertEquals ((long ) (dimensions + 14 ) * numVectors , in .getFilePointer ());
211229 }
212230 }
213231 }
232+
233+ private static void writeCorrections (OptimizedScalarQuantizer .QuantizationResult [] corrections , IndexOutput out ) throws IOException {
234+ for (OptimizedScalarQuantizer .QuantizationResult correction : corrections ) {
235+ out .writeInt (Float .floatToIntBits (correction .lowerInterval ()));
236+ }
237+ for (OptimizedScalarQuantizer .QuantizationResult correction : corrections ) {
238+ out .writeInt (Float .floatToIntBits (correction .upperInterval ()));
239+ }
240+ for (OptimizedScalarQuantizer .QuantizationResult correction : corrections ) {
241+ int targetComponentSum = correction .quantizedComponentSum ();
242+ assert targetComponentSum >= 0 && targetComponentSum <= 0xffff ;
243+ out .writeShort ((short ) targetComponentSum );
244+ }
245+ for (OptimizedScalarQuantizer .QuantizationResult correction : corrections ) {
246+ out .writeInt (Float .floatToIntBits (correction .additionalCorrection ()));
247+ }
248+ }
214249}
0 commit comments