Skip to content

Commit 4e41234

Browse files
authored
Updating tests to account for rewritting nested retrievers (#114502)
1 parent 9bdc590 commit 4e41234

File tree

2 files changed

+37
-41
lines changed

2 files changed

+37
-41
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,6 @@ tests:
357357
- class: org.elasticsearch.xpack.inference.InferenceRestIT
358358
method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint}
359359
issue: https://github.com/elastic/elasticsearch/issues/114376
360-
- class: org.elasticsearch.search.retriever.RankDocsRetrieverBuilderTests
361-
method: testRewrite
362-
issue: https://github.com/elastic/elasticsearch/issues/114467
363360
- class: org.elasticsearch.xpack.logsdb.LogsdbTestSuiteIT
364361
issue: https://github.com/elastic/elasticsearch/issues/114471
365362
- class: org.elasticsearch.packaging.test.DockerTests

server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,11 @@
2525
import java.util.List;
2626
import java.util.function.Supplier;
2727

28-
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
2928
import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector;
3029
import static org.hamcrest.Matchers.anyOf;
3130
import static org.hamcrest.Matchers.equalTo;
3231
import static org.hamcrest.Matchers.hasSize;
3332
import static org.hamcrest.Matchers.instanceOf;
34-
import static org.mockito.Mockito.mock;
3533

3634
public class RankDocsRetrieverBuilderTests extends ESTestCase {
3735

@@ -48,17 +46,22 @@ private Supplier<RankDoc[]> rankDocsSupplier() {
4846
return () -> rankDocs;
4947
}
5048

51-
private List<RetrieverBuilder> innerRetrievers() {
49+
private List<RetrieverBuilder> innerRetrievers(QueryRewriteContext queryRewriteContext) throws IOException {
5250
List<RetrieverBuilder> retrievers = new ArrayList<>();
5351
int numRetrievers = randomIntBetween(1, 10);
5452
for (int i = 0; i < numRetrievers; i++) {
5553
if (randomBoolean()) {
5654
StandardRetrieverBuilder standardRetrieverBuilder = new StandardRetrieverBuilder();
5755
standardRetrieverBuilder.queryBuilder = RandomQueryBuilder.createQuery(random());
5856
if (randomBoolean()) {
59-
standardRetrieverBuilder.preFilterQueryBuilders = preFilters();
57+
standardRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext);
6058
}
61-
retrievers.add(standardRetrieverBuilder);
59+
// RankDocsRetrieverBuilder assumes that the inner retrievers are already rewritten
60+
StandardRetrieverBuilder rewritten = (StandardRetrieverBuilder) Rewriteable.rewrite(
61+
standardRetrieverBuilder,
62+
queryRewriteContext
63+
);
64+
retrievers.add(rewritten);
6265
} else {
6366
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(
6467
randomAlphaOfLength(10),
@@ -69,30 +72,40 @@ private List<RetrieverBuilder> innerRetrievers() {
6972
randomFloat()
7073
);
7174
if (randomBoolean()) {
72-
knnRetrieverBuilder.preFilterQueryBuilders = preFilters();
75+
knnRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext);
7376
}
7477
knnRetrieverBuilder.rankDocs = rankDocsSupplier().get();
75-
retrievers.add(knnRetrieverBuilder);
78+
// RankDocsRetrieverBuilder assumes that the inner retrievers are already rewritten
79+
KnnRetrieverBuilder rewritten = (KnnRetrieverBuilder) Rewriteable.rewrite(knnRetrieverBuilder, queryRewriteContext);
80+
retrievers.add(rewritten);
7681
}
7782
}
7883
return retrievers;
7984
}
8085

81-
private List<QueryBuilder> preFilters() {
86+
private List<QueryBuilder> preFilters(QueryRewriteContext queryRewriteContext) throws IOException {
8287
List<QueryBuilder> preFilters = new ArrayList<>();
8388
int numPreFilters = randomInt(10);
8489
for (int i = 0; i < numPreFilters; i++) {
85-
preFilters.add(RandomQueryBuilder.createQuery(random()));
90+
QueryBuilder filter = RandomQueryBuilder.createQuery(random());
91+
QueryBuilder rewritten = Rewriteable.rewrite(filter, queryRewriteContext);
92+
preFilters.add(rewritten);
8693
}
8794
return preFilters;
8895
}
8996

90-
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder() {
91-
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(), rankDocsSupplier(), preFilters());
97+
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
98+
return new RankDocsRetrieverBuilder(
99+
randomIntBetween(1, 100),
100+
innerRetrievers(queryRewriteContext),
101+
rankDocsSupplier(),
102+
preFilters(queryRewriteContext)
103+
);
92104
}
93105

94-
public void testExtractToSearchSourceBuilder() {
95-
RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder();
106+
public void testExtractToSearchSourceBuilder() throws IOException {
107+
QueryRewriteContext queryRewriteContext = new QueryRewriteContext(parserConfig(), null, () -> 0L);
108+
RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(queryRewriteContext);
96109
SearchSourceBuilder source = new SearchSourceBuilder();
97110
if (randomBoolean()) {
98111
source.aggregation(new TermsAggregationBuilder("name").field("field"));
@@ -115,16 +128,18 @@ public void testExtractToSearchSourceBuilder() {
115128
assertNull(source.postFilter());
116129
}
117130

118-
public void testTopDocsQuery() {
119-
RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder();
131+
public void testTopDocsQuery() throws IOException {
132+
QueryRewriteContext queryRewriteContext = new QueryRewriteContext(parserConfig(), null, () -> 0L);
133+
RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(queryRewriteContext);
120134
QueryBuilder topDocs = retriever.topDocsQuery();
121135
assertNotNull(topDocs);
122136
assertThat(topDocs, instanceOf(BoolQueryBuilder.class));
123137
assertThat(((BoolQueryBuilder) topDocs).should(), hasSize(retriever.sources.size()));
124138
}
125139

126140
public void testRewrite() throws IOException {
127-
RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder();
141+
QueryRewriteContext queryRewriteContext = new QueryRewriteContext(parserConfig(), null, () -> 0L);
142+
RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(queryRewriteContext);
128143
boolean compoundAdded = false;
129144
if (randomBoolean()) {
130145
compoundAdded = true;
@@ -136,29 +151,13 @@ public boolean isCompound() {
136151
});
137152
}
138153
SearchSourceBuilder source = new SearchSourceBuilder().retriever(retriever);
139-
QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
140-
int size = source.size() < 0 ? DEFAULT_SIZE : source.size();
141-
if (retriever.rankWindowSize < size) {
142-
if (compoundAdded) {
143-
expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext));
144-
}
154+
if (compoundAdded) {
155+
expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext));
145156
} else {
146-
if (compoundAdded) {
147-
expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext));
148-
} else {
149-
SearchSourceBuilder rewrittenSource = Rewriteable.rewrite(source, queryRewriteContext);
150-
assertNull(rewrittenSource.retriever());
151-
assertTrue(rewrittenSource.knnSearch().isEmpty());
152-
assertThat(rewrittenSource.query(), instanceOf(RankDocsQueryBuilder.class));
153-
if (rewrittenSource.query() instanceof BoolQueryBuilder) {
154-
BoolQueryBuilder bq = (BoolQueryBuilder) rewrittenSource.query();
155-
assertThat(bq.filter().size(), equalTo(retriever.preFilterQueryBuilders.size()));
156-
assertThat(bq.must().size(), equalTo(1));
157-
assertThat(bq.must().get(0), instanceOf(BoolQueryBuilder.class));
158-
assertThat(bq.should().size(), equalTo(1));
159-
assertThat(bq.should().get(0), instanceOf(RankDocsQueryBuilder.class));
160-
}
161-
}
157+
SearchSourceBuilder rewrittenSource = Rewriteable.rewrite(source, queryRewriteContext);
158+
assertNull(rewrittenSource.retriever());
159+
assertTrue(rewrittenSource.knnSearch().isEmpty());
160+
assertThat(rewrittenSource.query(), instanceOf(RankDocsQueryBuilder.class));
162161
}
163162
}
164163
}

0 commit comments

Comments
 (0)