Skip to content

Commit b19d2e4

Browse files
markjhoyelasticsearchmachine
andauthored
Fix Bug in RankDocRetrieverBuilder when from is set to Default (-1) (elastic#137637) (elastic#137708)
* correct `from` in extractToSearchSourceBuilder * Update docs/changelog/137637.yaml * fix tests; simplify from check * revert auto-formatting * fix failing test; additional tests; changelog area * remove comment * [CI] Auto commit changes from spotless * cleanup unit and YAML tests * [CI] Auto commit changes from spotless * remove old, commented code --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 4ad0ef0 commit b19d2e4

File tree

5 files changed

+98
-56
lines changed

5 files changed

+98
-56
lines changed

docs/changelog/137637.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137637
2+
summary: Fix Bug in `RankDocRetrieverBuilder` when `from` is set to Default (-1)
3+
area: Search
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.index.query.QueryBuilder;
1414
import org.elasticsearch.index.query.QueryRewriteContext;
1515
import org.elasticsearch.index.query.RankDocsQueryBuilder;
16+
import org.elasticsearch.search.SearchService;
1617
import org.elasticsearch.search.builder.SearchSourceBuilder;
1718
import org.elasticsearch.search.rank.RankDoc;
1819
import org.elasticsearch.xcontent.XContentBuilder;
@@ -135,6 +136,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
135136
if (sourceHasMinScore()) {
136137
searchSourceBuilder.minScore(this.minScore == null ? Float.MIN_VALUE : this.minScore);
137138
}
139+
140+
if (searchSourceBuilder.from() < 0) {
141+
searchSourceBuilder.from(SearchService.DEFAULT_FROM);
142+
}
143+
138144
if (searchSourceBuilder.size() + searchSourceBuilder.from() > rankDocResults.length) {
139145
searchSourceBuilder.size(Math.max(0, rankDocResults.length - searchSourceBuilder.from()));
140146
}

server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.index.query.RandomQueryBuilder;
1616
import org.elasticsearch.index.query.RankDocsQueryBuilder;
1717
import org.elasticsearch.index.query.Rewriteable;
18+
import org.elasticsearch.search.SearchService;
1819
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
1920
import org.elasticsearch.search.builder.SearchSourceBuilder;
2021
import org.elasticsearch.search.rank.RankDoc;
@@ -124,6 +125,9 @@ public void testExtractToSearchSourceBuilder() throws IOException {
124125
}
125126
}
126127
assertNull(source.postFilter());
128+
129+
// the default `from` is -1, when `extractToSearchSourceBuilder` is run, it should modify this to the default
130+
assertEquals(SearchService.DEFAULT_FROM, source.from());
127131
}
128132

129133
public void testTopDocsQuery() throws IOException {

x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ setup:
398398
rank_window_size: 10
399399
inference_id: my-rerank-model
400400
inference_text: "How often does the moon hide the sun?"
401-
field: inference_text_field
401+
field: text
402402
size: 10
403403

404404
- match: { hits.total.value: 1 }
@@ -477,7 +477,7 @@ setup:
477477
rank_window_size: 10
478478
inference_id: my-rerank-model
479479
inference_text: "How often does the moon hide the sun?"
480-
field: inference_text_field
480+
field: text
481481
min_score: 0
482482
size: 10
483483

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 81 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
import static org.hamcrest.Matchers.containsString;
5353
import static org.hamcrest.Matchers.equalTo;
5454
import static org.hamcrest.Matchers.instanceOf;
55-
import static org.hamcrest.Matchers.lessThanOrEqualTo;
5655

5756
@ESIntegTestCase.ClusterScope(minNumDataNodes = 3)
5857
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
@@ -161,63 +160,91 @@ public void testRRFPagination() {
161160
for (int i = 0; i < randomIntBetween(1, 5); i++) {
162161
int from = randomIntBetween(0, totalDocs - 1);
163162
int size = randomIntBetween(1, totalDocs - from);
164-
for (int docs_to_fetch = from; docs_to_fetch < totalDocs; docs_to_fetch += size) {
163+
for (int from_value = from; from_value < totalDocs; from_value += size) {
165164
SearchSourceBuilder source = new SearchSourceBuilder();
166-
source.from(docs_to_fetch);
165+
source.from(from_value);
167166
source.size(size);
168-
// this one retrieves docs 1, 2, 4, 6, and 7
169-
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
170-
QueryBuilders.boolQuery()
171-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
172-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
173-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
174-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
175-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
176-
);
177-
// this one retrieves docs 2 and 6 due to prefilter
178-
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
179-
QueryBuilders.boolQuery()
180-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
181-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
182-
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
183-
);
184-
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
185-
// this one retrieves docs 2, 3, 6, and 7
186-
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(
187-
VECTOR_FIELD,
188-
new float[] { 2.0f },
189-
null,
190-
10,
191-
100,
192-
null,
193-
null,
194-
null
195-
);
196-
source.retriever(
197-
new RRFRetrieverBuilder(
198-
Arrays.asList(
199-
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
200-
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
201-
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
202-
),
203-
rankWindowSize,
204-
rankConstant
205-
)
206-
);
207-
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
208-
int fDocs_to_fetch = docs_to_fetch;
209-
ElasticsearchAssertions.assertResponse(req, resp -> {
210-
assertNull(resp.pointInTimeId());
211-
assertNotNull(resp.getHits().getTotalHits());
212-
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
213-
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
214-
assertThat(resp.getHits().getHits().length, lessThanOrEqualTo(size));
215-
for (int k = 0; k < Math.min(size, resp.getHits().getHits().length); k++) {
216-
assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + fDocs_to_fetch)));
217-
}
218-
});
167+
assertRRFPagination(source, from_value, size, rankWindowSize, rankConstant, expectedDocIds);
219168
}
220169
}
170+
171+
// test with `from` as the default (-1)
172+
for (int i = 0; i < randomIntBetween(5, 20); i++) {
173+
int size = randomIntBetween(1, totalDocs);
174+
SearchSourceBuilder source = new SearchSourceBuilder();
175+
source.size(size);
176+
assertRRFPagination(source, source.from(), size, rankWindowSize, rankConstant, expectedDocIds);
177+
}
178+
179+
// and finally test with from = default, and size > {total docs} to be sure
180+
SearchSourceBuilder source = new SearchSourceBuilder();
181+
source.size(totalDocs + 2);
182+
assertRRFPagination(source, source.from(), totalDocs, rankWindowSize, rankConstant, expectedDocIds);
183+
}
184+
185+
private void assertRRFPagination(
186+
SearchSourceBuilder source,
187+
int from,
188+
int maxExpectedSize,
189+
int rankWindowSize,
190+
int rankConstant,
191+
List<String> expectedDocIds
192+
) {
193+
// this one retrieves docs 1, 2, 4, 6, and 7
194+
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
195+
QueryBuilders.boolQuery()
196+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L))
197+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L))
198+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L))
199+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L))
200+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L))
201+
);
202+
// this one retrieves docs 2 and 6 due to prefilter
203+
StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(
204+
QueryBuilders.boolQuery()
205+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L))
206+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L))
207+
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L))
208+
);
209+
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
210+
// this one retrieves docs 2, 3, 6, and 7
211+
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(
212+
VECTOR_FIELD,
213+
new float[] { 2.0f },
214+
null,
215+
10,
216+
100,
217+
null,
218+
null,
219+
null
220+
);
221+
source.retriever(
222+
new RRFRetrieverBuilder(
223+
Arrays.asList(
224+
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
225+
new CompoundRetrieverBuilder.RetrieverSource(standard1, null),
226+
new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null)
227+
),
228+
rankWindowSize,
229+
rankConstant
230+
)
231+
);
232+
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
233+
234+
int innerFrom = Math.max(from, 0);
235+
ElasticsearchAssertions.assertResponse(req, resp -> {
236+
assertNull(resp.pointInTimeId());
237+
assertNotNull(resp.getHits().getTotalHits());
238+
assertThat(resp.getHits().getTotalHits().value(), equalTo(6L));
239+
assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO));
240+
241+
int expectedSize = innerFrom + maxExpectedSize > 6 ? 6 - innerFrom : maxExpectedSize;
242+
assertThat(resp.getHits().getHits().length, equalTo(expectedSize));
243+
244+
for (int k = 0; k < expectedSize; k++) {
245+
assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + innerFrom)));
246+
}
247+
});
221248
}
222249

223250
public void testRRFWithAggs() {

0 commit comments

Comments
 (0)