2020import org .elasticsearch .index .mapper .FieldTypeTestCase ;
2121import org .elasticsearch .index .mapper .MappedFieldType ;
2222import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .DenseVectorFieldType ;
23+ import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .ElementType ;
2324import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .VectorSimilarity ;
2425import org .elasticsearch .search .DocValueFormat ;
2526import org .elasticsearch .search .vectors .DenseVectorQuery ;
27+ import org .elasticsearch .search .vectors .ESKnnByteVectorQuery ;
28+ import org .elasticsearch .search .vectors .ESKnnFloatVectorQuery ;
29+ import org .elasticsearch .search .vectors .RescoreKnnVectorQuery ;
2630import org .elasticsearch .search .vectors .VectorData ;
2731
2832import java .io .IOException ;
3135import java .util .Set ;
3236
3337import static org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .BBQ_MIN_DIMS ;
38+ import static org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .ElementType .BYTE ;
39+ import static org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .ElementType .FLOAT ;
3440import static org .hamcrest .Matchers .containsString ;
41+ import static org .hamcrest .Matchers .equalTo ;
3542import static org .hamcrest .Matchers .instanceOf ;
43+ import static org .hamcrest .Matchers .is ;
3644
3745public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
3846 private final boolean indexed ;
@@ -69,11 +77,27 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() {
6977 );
7078 }
7179
80+ private DenseVectorFieldMapper .IndexOptions randomIndexOptionsHnswQuantized () {
81+ return randomFrom (
82+ new DenseVectorFieldMapper .Int8HnswIndexOptions (
83+ randomIntBetween (1 , 100 ),
84+ randomIntBetween (1 , 10_000 ),
85+ randomFrom ((Float ) null , 0f , (float ) randomDoubleBetween (0.9 , 1.0 , true ))
86+ ),
87+ new DenseVectorFieldMapper .Int4HnswIndexOptions (
88+ randomIntBetween (1 , 100 ),
89+ randomIntBetween (1 , 10_000 ),
90+ randomFrom ((Float ) null , 0f , (float ) randomDoubleBetween (0.9 , 1.0 , true ))
91+ ),
92+ new DenseVectorFieldMapper .BBQHnswIndexOptions (randomIntBetween (1 , 100 ), randomIntBetween (1 , 10_000 ))
93+ );
94+ }
95+
7296 private DenseVectorFieldType createFloatFieldType () {
7397 return new DenseVectorFieldType (
7498 "f" ,
7599 IndexVersion .current (),
76- DenseVectorFieldMapper . ElementType . FLOAT ,
100+ FLOAT ,
77101 BBQ_MIN_DIMS ,
78102 indexed ,
79103 VectorSimilarity .COSINE ,
@@ -86,7 +110,7 @@ private DenseVectorFieldType createByteFieldType() {
86110 return new DenseVectorFieldType (
87111 "f" ,
88112 IndexVersion .current (),
89- DenseVectorFieldMapper . ElementType . BYTE ,
113+ BYTE ,
90114 5 ,
91115 true ,
92116 VectorSimilarity .COSINE ,
@@ -159,7 +183,7 @@ public void testCreateNestedKnnQuery() {
159183 DenseVectorFieldType field = new DenseVectorFieldType (
160184 "f" ,
161185 IndexVersion .current (),
162- DenseVectorFieldMapper . ElementType . FLOAT ,
186+ FLOAT ,
163187 dims ,
164188 true ,
165189 VectorSimilarity .COSINE ,
@@ -177,7 +201,7 @@ public void testCreateNestedKnnQuery() {
177201 DenseVectorFieldType field = new DenseVectorFieldType (
178202 "f" ,
179203 IndexVersion .current (),
180- DenseVectorFieldMapper . ElementType . BYTE ,
204+ BYTE ,
181205 dims ,
182206 true ,
183207 VectorSimilarity .COSINE ,
@@ -209,7 +233,7 @@ public void testExactKnnQuery() {
209233 DenseVectorFieldType field = new DenseVectorFieldType (
210234 "f" ,
211235 IndexVersion .current (),
212- DenseVectorFieldMapper . ElementType . FLOAT ,
236+ FLOAT ,
213237 dims ,
214238 true ,
215239 VectorSimilarity .COSINE ,
@@ -227,7 +251,7 @@ public void testExactKnnQuery() {
227251 DenseVectorFieldType field = new DenseVectorFieldType (
228252 "f" ,
229253 IndexVersion .current (),
230- DenseVectorFieldMapper . ElementType . BYTE ,
254+ BYTE ,
231255 dims ,
232256 true ,
233257 VectorSimilarity .COSINE ,
@@ -247,7 +271,7 @@ public void testFloatCreateKnnQuery() {
247271 DenseVectorFieldType unindexedField = new DenseVectorFieldType (
248272 "f" ,
249273 IndexVersion .current (),
250- DenseVectorFieldMapper . ElementType . FLOAT ,
274+ FLOAT ,
251275 4 ,
252276 false ,
253277 VectorSimilarity .COSINE ,
@@ -271,7 +295,7 @@ public void testFloatCreateKnnQuery() {
271295 DenseVectorFieldType dotProductField = new DenseVectorFieldType (
272296 "f" ,
273297 IndexVersion .current (),
274- DenseVectorFieldMapper . ElementType . FLOAT ,
298+ FLOAT ,
275299 BBQ_MIN_DIMS ,
276300 true ,
277301 VectorSimilarity .DOT_PRODUCT ,
@@ -291,7 +315,7 @@ public void testFloatCreateKnnQuery() {
291315 DenseVectorFieldType cosineField = new DenseVectorFieldType (
292316 "f" ,
293317 IndexVersion .current (),
294- DenseVectorFieldMapper . ElementType . FLOAT ,
318+ FLOAT ,
295319 BBQ_MIN_DIMS ,
296320 true ,
297321 VectorSimilarity .COSINE ,
@@ -310,7 +334,7 @@ public void testCreateKnnQueryMaxDims() {
310334 DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType (
311335 "f" ,
312336 IndexVersion .current (),
313- DenseVectorFieldMapper . ElementType . FLOAT ,
337+ FLOAT ,
314338 4096 ,
315339 true ,
316340 VectorSimilarity .COSINE ,
@@ -329,7 +353,7 @@ public void testCreateKnnQueryMaxDims() {
329353 DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType (
330354 "f" ,
331355 IndexVersion .current (),
332- DenseVectorFieldMapper . ElementType . BYTE ,
356+ BYTE ,
333357 4096 ,
334358 true ,
335359 VectorSimilarity .COSINE ,
@@ -350,7 +374,7 @@ public void testByteCreateKnnQuery() {
350374 DenseVectorFieldType unindexedField = new DenseVectorFieldType (
351375 "f" ,
352376 IndexVersion .current (),
353- DenseVectorFieldMapper . ElementType . BYTE ,
377+ BYTE ,
354378 3 ,
355379 false ,
356380 VectorSimilarity .COSINE ,
@@ -366,7 +390,7 @@ public void testByteCreateKnnQuery() {
366390 DenseVectorFieldType cosineField = new DenseVectorFieldType (
367391 "f" ,
368392 IndexVersion .current (),
369- DenseVectorFieldMapper . ElementType . BYTE ,
393+ BYTE ,
370394 3 ,
371395 true ,
372396 VectorSimilarity .COSINE ,
@@ -385,4 +409,84 @@ public void testByteCreateKnnQuery() {
385409 );
386410 assertThat (e .getMessage (), containsString ("The [cosine] similarity does not support vectors with zero magnitude." ));
387411 }
412+
413+ public void testRescoreOversampleUsedWithoutQuantization () {
414+ ElementType elementType = randomFrom (BYTE , FLOAT );
415+ DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType (
416+ "f" ,
417+ IndexVersion .current (),
418+ elementType ,
419+ 3 ,
420+ true ,
421+ VectorSimilarity .COSINE ,
422+ randomIndexOptionsNonQuantized (),
423+ Collections .emptyMap ()
424+ );
425+
426+ Query knnQuery = nonQuantizedField .createKnnQuery (
427+ new VectorData (null , new byte []{1 , 4 , 10 }),
428+ 10 ,
429+ 100 ,
430+ randomFloatBetween (1.0F , 10.0F , false ),
431+ null ,
432+ null ,
433+ null
434+ );
435+
436+ if (elementType == BYTE ) {
437+ ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery ) knnQuery ;
438+ assertThat (esKnnQuery .getK (), is (100 ));
439+ assertThat (esKnnQuery .kParam (), is (10 ));
440+ } else {
441+ ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery ) knnQuery ;
442+ assertThat (esKnnQuery .getK (), is (100 ));
443+ assertThat (esKnnQuery .kParam (), is (10 ));
444+ }
445+ }
446+
447+ public void testRescoreOversampleModifiesKnnParams () {
448+ DenseVectorFieldType fieldType = new DenseVectorFieldType (
449+ "f" ,
450+ IndexVersion .current (),
451+ randomFrom (BYTE , FLOAT ),
452+ 3 ,
453+ true ,
454+ VectorSimilarity .COSINE ,
455+ randomIndexOptionsHnswQuantized (),
456+ Collections .emptyMap ()
457+ );
458+
459+ // Total results is k, internal k is multiplied by oversample
460+ checkRescoreQueryParameters (fieldType , 10 , 200 , 2.5F , 10 , 25 , 200 );
461+ // If numCands < k, update numCands to k
462+ checkRescoreQueryParameters (fieldType , 10 , 20 , 2.5F , 10 , 25 , 25 );
463+ // Oversampling limit
464+ checkRescoreQueryParameters (fieldType , 1000 , 1000 , 11.0F , 1000 , 10000 , 10000 );
465+ checkRescoreQueryParameters (fieldType , 5000 , 7500 , 2.5F , 5000 , 10000 , 10000 );
466+ }
467+
468+ private static void checkRescoreQueryParameters (
469+ DenseVectorFieldType fieldType ,
470+ int k ,
471+ int candidates ,
472+ float oversample ,
473+ int expectedResults ,
474+ int expectedK ,
475+ int expectedCandidates
476+ ) {
477+ Query query = fieldType .createKnnQuery (new VectorData (null , new byte [] { 1 , 4 , 10 }), k , candidates , oversample , null , null , null );
478+
479+ RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery ) query ;
480+ if (fieldType .getElementType () == BYTE ) {
481+ ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery ) rescoreQuery .innerQuery ();
482+ assertThat ("Unexpected total results" , rescoreQuery .k (), equalTo (expectedResults ));
483+ assertThat ("Unexpected k parameter" , esKnnQuery .kParam (), equalTo (expectedK ));
484+ assertThat ("Unexpected candidates" , esKnnQuery .getK (), equalTo (expectedCandidates ));
485+ } else {
486+ ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery ) rescoreQuery .innerQuery ();
487+ assertThat ("Unexpected total results" , rescoreQuery .k (), equalTo (expectedResults ));
488+ assertThat ("Unexpected k parameter" , esKnnQuery .kParam (), equalTo (expectedK ));
489+ assertThat ("Unexpected candidates" , esKnnQuery .getK (), equalTo (expectedCandidates ));
490+ }
491+ }
388492}
0 commit comments