|
11 | 11 |
|
12 | 12 | import org.apache.lucene.search.KnnByteVectorQuery; |
13 | 13 | import org.apache.lucene.search.KnnFloatVectorQuery; |
| 14 | +import org.apache.lucene.search.PatienceKnnVectorQuery; |
14 | 15 | import org.apache.lucene.search.Query; |
15 | 16 | import org.apache.lucene.search.join.BitSetProducer; |
16 | 17 | import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; |
@@ -478,7 +479,7 @@ public void testCreateKnnQueryMaxDims() { |
478 | 479 | null, |
479 | 480 | randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) |
480 | 481 | ); |
481 | | - assertThat(query, instanceOf(KnnByteVectorQuery.class)); |
| 482 | + assertThat(query, instanceOf(ESKnnByteVectorQuery.class)); |
482 | 483 | } |
483 | 484 | } |
484 | 485 |
|
@@ -577,16 +578,20 @@ public void testRescoreOversampleUsedWithoutQuantization() { |
577 | 578 | ); |
578 | 579 |
|
579 | 580 | if (elementType == BYTE) { |
580 | | - KnnByteVectorQuery knnByteVectorQuery = (KnnByteVectorQuery) knnQuery; |
581 | | - assertThat(knnByteVectorQuery.getK(), is(100)); |
582 | | - if (knnByteVectorQuery instanceof ESKnnByteVectorQuery esKnnByteVectorQuery) { |
583 | | - assertThat(esKnnByteVectorQuery.kParam(), is(10)); |
| 581 | + if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { |
| 582 | + assertThat(patienceKnnVectorQuery.getK(), is(100)); |
| 583 | + } else { |
| 584 | + ESKnnByteVectorQuery knnByteVectorQuery = (ESKnnByteVectorQuery) knnQuery; |
| 585 | + assertThat(knnByteVectorQuery.getK(), is(100)); |
| 586 | + assertThat(knnByteVectorQuery.kParam(), is(10)); |
584 | 587 | } |
585 | 588 | } else { |
586 | | - KnnFloatVectorQuery knnFloatVectorQuery = (KnnFloatVectorQuery) knnQuery; |
587 | | - assertThat(knnFloatVectorQuery.getK(), is(100)); |
588 | | - if (knnFloatVectorQuery instanceof ESKnnFloatVectorQuery esKnnFloatVectorQuery) { |
589 | | - assertThat(esKnnFloatVectorQuery.kParam(), is(10)); |
| 589 | + if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { |
| 590 | + assertThat(patienceKnnVectorQuery.getK(), is(100)); |
| 591 | + } else { |
| 592 | + ESKnnFloatVectorQuery knnFloatVectorQuery = (ESKnnFloatVectorQuery) knnQuery; |
| 593 | + assertThat(knnFloatVectorQuery.getK(), is(100)); |
| 594 | + assertThat(knnFloatVectorQuery.kParam(), is(10)); |
590 | 595 | } |
591 | 596 | } |
592 | 597 | } |
@@ -635,7 +640,7 @@ public void testRescoreOversampleQueryOverrides() { |
635 | 640 | null, |
636 | 641 | randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) |
637 | 642 | ); |
638 | | - assertTrue(query instanceof KnnFloatVectorQuery); |
| 643 | + assertTrue(query instanceof ESKnnFloatVectorQuery); |
639 | 644 |
|
640 | 645 | // verify we can override a `0` to a positive number |
641 | 646 | fieldType = new DenseVectorFieldType( |
@@ -738,11 +743,14 @@ private static void checkRescoreQueryParameters( |
738 | 743 | randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()) |
739 | 744 | ); |
740 | 745 | RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; |
741 | | - KnnFloatVectorQuery knnQuery = (KnnFloatVectorQuery) rescoreQuery.innerQuery(); |
742 | | - assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); |
743 | | - assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates)); |
744 | | - if (knnQuery instanceof ESKnnFloatVectorQuery esKnnFloatVectorQuery) { |
745 | | - assertThat("Unexpected k parameter", esKnnFloatVectorQuery.kParam(), equalTo(expectedK)); |
| 746 | + Query innerQuery = rescoreQuery.innerQuery(); |
| 747 | + if (innerQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) { |
| 748 | + assertThat("Unexpected candidates", patienceKnnVectorQuery.getK(), equalTo(expectedCandidates)); |
| 749 | + } else { |
| 750 | + ESKnnFloatVectorQuery knnQuery = (ESKnnFloatVectorQuery) innerQuery; |
| 751 | + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); |
| 752 | + assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates)); |
| 753 | + assertThat("Unexpected k parameter", knnQuery.kParam(), equalTo(expectedK)); |
746 | 754 | } |
747 | 755 | } |
748 | 756 | } |
0 commit comments