2525import java .util .List ;
2626import java .util .function .Supplier ;
2727
28- import static org .elasticsearch .search .SearchService .DEFAULT_SIZE ;
2928import static org .elasticsearch .search .vectors .KnnSearchBuilderTests .randomVector ;
3029import static org .hamcrest .Matchers .anyOf ;
3130import static org .hamcrest .Matchers .equalTo ;
3231import static org .hamcrest .Matchers .hasSize ;
3332import static org .hamcrest .Matchers .instanceOf ;
34- import static org .mockito .Mockito .mock ;
3533
3634public class RankDocsRetrieverBuilderTests extends ESTestCase {
3735
@@ -48,17 +46,22 @@ private Supplier<RankDoc[]> rankDocsSupplier() {
4846 return () -> rankDocs ;
4947 }
5048
51- private List <RetrieverBuilder > innerRetrievers () {
49+ private List <RetrieverBuilder > innerRetrievers (QueryRewriteContext queryRewriteContext ) throws IOException {
5250 List <RetrieverBuilder > retrievers = new ArrayList <>();
5351 int numRetrievers = randomIntBetween (1 , 10 );
5452 for (int i = 0 ; i < numRetrievers ; i ++) {
5553 if (randomBoolean ()) {
5654 StandardRetrieverBuilder standardRetrieverBuilder = new StandardRetrieverBuilder ();
5755 standardRetrieverBuilder .queryBuilder = RandomQueryBuilder .createQuery (random ());
5856 if (randomBoolean ()) {
59- standardRetrieverBuilder .preFilterQueryBuilders = preFilters ();
57+ standardRetrieverBuilder .preFilterQueryBuilders = preFilters (queryRewriteContext );
6058 }
61- retrievers .add (standardRetrieverBuilder );
59+ // RankDocsRetrieverBuilder assumes that the inner retrievers are already rewritten
60+ StandardRetrieverBuilder rewritten = (StandardRetrieverBuilder ) Rewriteable .rewrite (
61+ standardRetrieverBuilder ,
62+ queryRewriteContext
63+ );
64+ retrievers .add (rewritten );
6265 } else {
6366 KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder (
6467 randomAlphaOfLength (10 ),
@@ -69,30 +72,40 @@ private List<RetrieverBuilder> innerRetrievers() {
6972 randomFloat ()
7073 );
7174 if (randomBoolean ()) {
72- knnRetrieverBuilder .preFilterQueryBuilders = preFilters ();
75+ knnRetrieverBuilder .preFilterQueryBuilders = preFilters (queryRewriteContext );
7376 }
7477 knnRetrieverBuilder .rankDocs = rankDocsSupplier ().get ();
75- retrievers .add (knnRetrieverBuilder );
78+ // RankDocsRetrieverBuilder assumes that the inner retrievers are already rewritten
79+ KnnRetrieverBuilder rewritten = (KnnRetrieverBuilder ) Rewriteable .rewrite (knnRetrieverBuilder , queryRewriteContext );
80+ retrievers .add (rewritten );
7681 }
7782 }
7883 return retrievers ;
7984 }
8085
81- private List <QueryBuilder > preFilters () {
86+ private List <QueryBuilder > preFilters (QueryRewriteContext queryRewriteContext ) throws IOException {
8287 List <QueryBuilder > preFilters = new ArrayList <>();
8388 int numPreFilters = randomInt (10 );
8489 for (int i = 0 ; i < numPreFilters ; i ++) {
85- preFilters .add (RandomQueryBuilder .createQuery (random ()));
90+ QueryBuilder filter = RandomQueryBuilder .createQuery (random ());
91+ QueryBuilder rewritten = Rewriteable .rewrite (filter , queryRewriteContext );
92+ preFilters .add (rewritten );
8693 }
8794 return preFilters ;
8895 }
8996
90- private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder () {
91- return new RankDocsRetrieverBuilder (randomIntBetween (1 , 100 ), innerRetrievers (), rankDocsSupplier (), preFilters ());
97+ private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder (QueryRewriteContext queryRewriteContext ) throws IOException {
98+ return new RankDocsRetrieverBuilder (
99+ randomIntBetween (1 , 100 ),
100+ innerRetrievers (queryRewriteContext ),
101+ rankDocsSupplier (),
102+ preFilters (queryRewriteContext )
103+ );
92104 }
93105
94- public void testExtractToSearchSourceBuilder () {
95- RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder ();
106+ public void testExtractToSearchSourceBuilder () throws IOException {
107+ QueryRewriteContext queryRewriteContext = new QueryRewriteContext (parserConfig (), null , () -> 0L );
108+ RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder (queryRewriteContext );
96109 SearchSourceBuilder source = new SearchSourceBuilder ();
97110 if (randomBoolean ()) {
98111 source .aggregation (new TermsAggregationBuilder ("name" ).field ("field" ));
@@ -115,16 +128,18 @@ public void testExtractToSearchSourceBuilder() {
115128 assertNull (source .postFilter ());
116129 }
117130
118- public void testTopDocsQuery () {
119- RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder ();
131+ public void testTopDocsQuery () throws IOException {
132+ QueryRewriteContext queryRewriteContext = new QueryRewriteContext (parserConfig (), null , () -> 0L );
133+ RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder (queryRewriteContext );
120134 QueryBuilder topDocs = retriever .topDocsQuery ();
121135 assertNotNull (topDocs );
122136 assertThat (topDocs , instanceOf (BoolQueryBuilder .class ));
123137 assertThat (((BoolQueryBuilder ) topDocs ).should (), hasSize (retriever .sources .size ()));
124138 }
125139
126140 public void testRewrite () throws IOException {
127- RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder ();
141+ QueryRewriteContext queryRewriteContext = new QueryRewriteContext (parserConfig (), null , () -> 0L );
142+ RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder (queryRewriteContext );
128143 boolean compoundAdded = false ;
129144 if (randomBoolean ()) {
130145 compoundAdded = true ;
@@ -136,29 +151,13 @@ public boolean isCompound() {
136151 });
137152 }
138153 SearchSourceBuilder source = new SearchSourceBuilder ().retriever (retriever );
139- QueryRewriteContext queryRewriteContext = mock (QueryRewriteContext .class );
140- int size = source .size () < 0 ? DEFAULT_SIZE : source .size ();
141- if (retriever .rankWindowSize < size ) {
142- if (compoundAdded ) {
143- expectThrows (AssertionError .class , () -> Rewriteable .rewrite (source , queryRewriteContext ));
144- }
154+ if (compoundAdded ) {
155+ expectThrows (AssertionError .class , () -> Rewriteable .rewrite (source , queryRewriteContext ));
145156 } else {
146- if (compoundAdded ) {
147- expectThrows (AssertionError .class , () -> Rewriteable .rewrite (source , queryRewriteContext ));
148- } else {
149- SearchSourceBuilder rewrittenSource = Rewriteable .rewrite (source , queryRewriteContext );
150- assertNull (rewrittenSource .retriever ());
151- assertTrue (rewrittenSource .knnSearch ().isEmpty ());
152- assertThat (rewrittenSource .query (), instanceOf (RankDocsQueryBuilder .class ));
153- if (rewrittenSource .query () instanceof BoolQueryBuilder ) {
154- BoolQueryBuilder bq = (BoolQueryBuilder ) rewrittenSource .query ();
155- assertThat (bq .filter ().size (), equalTo (retriever .preFilterQueryBuilders .size ()));
156- assertThat (bq .must ().size (), equalTo (1 ));
157- assertThat (bq .must ().get (0 ), instanceOf (BoolQueryBuilder .class ));
158- assertThat (bq .should ().size (), equalTo (1 ));
159- assertThat (bq .should ().get (0 ), instanceOf (RankDocsQueryBuilder .class ));
160- }
161- }
157+ SearchSourceBuilder rewrittenSource = Rewriteable .rewrite (source , queryRewriteContext );
158+ assertNull (rewrittenSource .retriever ());
159+ assertTrue (rewrittenSource .knnSearch ().isEmpty ());
160+ assertThat (rewrittenSource .query (), instanceOf (RankDocsQueryBuilder .class ));
162161 }
163162 }
164163}
0 commit comments