1919import org .apache .lucene .store .IndexInput ;
2020import org .apache .lucene .store .IndexOutput ;
2121import org .apache .lucene .store .MMapDirectory ;
22- import org .apache .lucene .util .hnsw .RandomVectorScorer ;
2322import org .apache .lucene .util .hnsw .RandomVectorScorerSupplier ;
23+ import org .apache .lucene .util .hnsw .UpdateableRandomVectorScorer ;
2424import org .apache .lucene .util .quantization .QuantizedByteVectorValues ;
2525import org .apache .lucene .util .quantization .ScalarQuantizer ;
2626
5050// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100)
5151public class VectorScorerFactoryTests extends AbstractVectorTestCase {
5252
53+ private static final float DELTA = 1e-4f ;
54+
5355 // bounds of the range of values that can be seen by int7 scalar quantized vectors
5456 static final byte MIN_INT7_VALUE = 0 ;
5557 static final byte MAX_INT7_VALUE = 127 ;
@@ -99,10 +101,13 @@ void testSimpleImpl(long maxChunkSize) throws IOException {
99101 float scc = values .getScalarQuantizer ().getConstantMultiplier ();
100102 float expected = luceneScore (sim , vec1 , vec2 , scc , vec1Correction , vec2Correction );
101103
102- var luceneSupplier = luceneScoreSupplier (values , VectorSimilarityType .of (sim )).scorer (0 );
104+ var luceneSupplier = luceneScoreSupplier (values , VectorSimilarityType .of (sim )).scorer ();
105+ luceneSupplier .setScoringOrdinal (0 );
103106 assertThat (luceneSupplier .score (1 ), equalTo (expected ));
104107 var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , scc ).get ();
105- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
108+ var scorer = supplier .scorer ();
109+ scorer .setScoringOrdinal (0 );
110+ assertThat (scorer .score (1 ), equalTo (expected ));
106111
107112 if (Runtime .version ().feature () >= 22 ) {
108113 var qScorer = factory .getInt7SQVectorScorer (VectorSimilarityType .of (sim ), values , query1 ).get ();
@@ -134,24 +139,32 @@ public void testNonNegativeDotProduct() throws IOException {
134139 float expected = 0f ;
135140 assertThat (luceneScore (DOT_PRODUCT , vec1 , vec2 , 1 , -5 , -5 ), equalTo (expected ));
136141 var supplier = factory .getInt7SQVectorScorerSupplier (DOT_PRODUCT , in , values , 1 ).get ();
137- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
138- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
142+ var scorer = supplier .scorer ();
143+ scorer .setScoringOrdinal (0 );
144+ assertThat (scorer .score (1 ), equalTo (expected ));
145+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
139146 // max inner product
140147 expected = luceneScore (MAXIMUM_INNER_PRODUCT , vec1 , vec2 , 1 , -5 , -5 );
141148 supplier = factory .getInt7SQVectorScorerSupplier (MAXIMUM_INNER_PRODUCT , in , values , 1 ).get ();
142- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
143- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
149+ scorer = supplier .scorer ();
150+ scorer .setScoringOrdinal (0 );
151+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
152+ assertThat (scorer .score (1 ), equalTo (expected ));
144153 // cosine
145154 expected = 0f ;
146155 assertThat (luceneScore (COSINE , vec1 , vec2 , 1 , -5 , -5 ), equalTo (expected ));
147156 supplier = factory .getInt7SQVectorScorerSupplier (COSINE , in , values , 1 ).get ();
148- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
149- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
157+ scorer = supplier .scorer ();
158+ scorer .setScoringOrdinal (0 );
159+ assertThat (scorer .score (1 ), equalTo (expected ));
160+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
150161 // euclidean
151162 expected = luceneScore (EUCLIDEAN , vec1 , vec2 , 1 , -5 , -5 );
152163 supplier = factory .getInt7SQVectorScorerSupplier (EUCLIDEAN , in , values , 1 ).get ();
153- assertThat (supplier .scorer (0 ).score (1 ), equalTo (expected ));
154- assertThat (supplier .scorer (0 ).score (1 ), greaterThanOrEqualTo (0f ));
164+ scorer = supplier .scorer ();
165+ scorer .setScoringOrdinal (0 );
166+ assertThat (scorer .score (1 ), equalTo (expected ));
167+ assertThat (scorer .score (1 ), greaterThanOrEqualTo (0f ));
155168 }
156169 }
157170 }
@@ -208,7 +221,9 @@ void testRandomSupplier(long maxChunkSize, Function<Integer, byte[]> byteArraySu
208221 var values = vectorValues (dims , size , in , VectorSimilarityType .of (sim ));
209222 float expected = luceneScore (sim , vectors [idx0 ], vectors [idx1 ], correction , offsets [idx0 ], offsets [idx1 ]);
210223 var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , correction ).get ();
211- assertThat (supplier .scorer (idx0 ).score (idx1 ), equalTo (expected ));
224+ var scorer = supplier .scorer ();
225+ scorer .setScoringOrdinal (idx0 );
226+ assertThat (scorer .score (idx1 ), equalTo (expected ));
212227 }
213228 }
214229 }
@@ -265,7 +280,7 @@ void testRandomScorerImpl(long maxChunkSize, Function<Integer, float[]> floatArr
265280
266281 var expected = luceneScore (sim , qVectors [idx0 ], qVectors [idx1 ], correction , corrections [idx0 ], corrections [idx1 ]);
267282 var scorer = factory .getInt7SQVectorScorer (VectorSimilarityType .of (sim ), values , vectors [idx0 ]).get ();
268- assertThat (scorer .score (idx1 ), equalTo ( expected ) );
283+ assertEquals (scorer .score (idx1 ), expected , DELTA );
269284 }
270285 }
271286 }
@@ -313,7 +328,9 @@ void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding, Functi
313328 var values = vectorValues (dims , size , in , VectorSimilarityType .of (sim ));
314329 float expected = luceneScore (sim , vectors [idx0 ], vectors [idx1 ], correction , offsets [idx0 ], offsets [idx1 ]);
315330 var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , correction ).get ();
316- assertThat (supplier .scorer (idx0 ).score (idx1 ), equalTo (expected ));
331+ var scorer = supplier .scorer ();
332+ scorer .setScoringOrdinal (idx0 );
333+ assertThat (scorer .score (idx1 ), equalTo (expected ));
317334 }
318335 }
319336 }
@@ -352,7 +369,9 @@ public void testLarge() throws IOException {
352369 var values = vectorValues (dims , size , in , VectorSimilarityType .of (sim ));
353370 float expected = luceneScore (sim , vector (idx0 , dims ), vector (idx1 , dims ), correction , off0 , off1 );
354371 var supplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , correction ).get ();
355- assertThat (supplier .scorer (idx0 ).score (idx1 ), equalTo (expected ));
372+ var scorer = supplier .scorer ();
373+ scorer .setScoringOrdinal (idx0 );
374+ assertThat (scorer .score (idx1 ), equalTo (expected ));
356375 }
357376 }
358377 }
@@ -391,8 +410,8 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception {
391410 var values = vectorValues (dims , 4 , in , VectorSimilarityType .of (sim ));
392411 var scoreSupplier = factory .getInt7SQVectorScorerSupplier (sim , in , values , 1f ).get ();
393412 var tasks = List .<Callable <Optional <Throwable >>>of (
394- new ScoreCallable (scoreSupplier .copy ().scorer (0 ) , 1 , expectedScore1 ),
395- new ScoreCallable (scoreSupplier .copy ().scorer (2 ) , 3 , expectedScore2 )
413+ new ScoreCallable (scoreSupplier .copy ().scorer (), 0 , 1 , expectedScore1 ),
414+ new ScoreCallable (scoreSupplier .copy ().scorer (), 2 , 3 , expectedScore2 )
396415 );
397416 var executor = Executors .newFixedThreadPool (2 );
398417 var results = executor .invokeAll (tasks );
@@ -408,14 +427,19 @@ void testRaceImpl(VectorSimilarityType sim) throws Exception {
408427
409428 static class ScoreCallable implements Callable <Optional <Throwable >> {
410429
411- final RandomVectorScorer scorer ;
430+ final UpdateableRandomVectorScorer scorer ;
412431 final int ord ;
413432 final float expectedScore ;
414433
415- ScoreCallable (RandomVectorScorer scorer , int ord , float expectedScore ) {
416- this .scorer = scorer ;
417- this .ord = ord ;
418- this .expectedScore = expectedScore ;
434+ ScoreCallable (UpdateableRandomVectorScorer scorer , int queryOrd , int ord , float expectedScore ) {
435+ try {
436+ this .scorer = scorer ;
437+ this .scorer .setScoringOrdinal (queryOrd );
438+ this .ord = ord ;
439+ this .expectedScore = expectedScore ;
440+ } catch (IOException e ) {
441+ throw new RuntimeException (e );
442+ }
419443 }
420444
421445 @ Override
0 commit comments