|
52 | 52 | import static org.hamcrest.Matchers.containsString; |
53 | 53 | import static org.hamcrest.Matchers.equalTo; |
54 | 54 | import static org.hamcrest.Matchers.instanceOf; |
55 | | -import static org.hamcrest.Matchers.lessThanOrEqualTo; |
56 | 55 |
|
57 | 56 | @ESIntegTestCase.ClusterScope(minNumDataNodes = 3) |
58 | 57 | public class RRFRetrieverBuilderIT extends ESIntegTestCase { |
@@ -161,63 +160,91 @@ public void testRRFPagination() { |
161 | 160 | for (int i = 0; i < randomIntBetween(1, 5); i++) { |
162 | 161 | int from = randomIntBetween(0, totalDocs - 1); |
163 | 162 | int size = randomIntBetween(1, totalDocs - from); |
164 | | - for (int docs_to_fetch = from; docs_to_fetch < totalDocs; docs_to_fetch += size) { |
| 163 | + for (int from_value = from; from_value < totalDocs; from_value += size) { |
165 | 164 | SearchSourceBuilder source = new SearchSourceBuilder(); |
166 | | - source.from(docs_to_fetch); |
| 165 | + source.from(from_value); |
167 | 166 | source.size(size); |
168 | | - // this one retrieves docs 1, 2, 4, 6, and 7 |
169 | | - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( |
170 | | - QueryBuilders.boolQuery() |
171 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) |
172 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) |
173 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) |
174 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) |
175 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) |
176 | | - ); |
177 | | - // this one retrieves docs 2 and 6 due to prefilter |
178 | | - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( |
179 | | - QueryBuilders.boolQuery() |
180 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) |
181 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) |
182 | | - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) |
183 | | - ); |
184 | | - standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); |
185 | | - // this one retrieves docs 2, 3, 6, and 7 |
186 | | - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( |
187 | | - VECTOR_FIELD, |
188 | | - new float[] { 2.0f }, |
189 | | - null, |
190 | | - 10, |
191 | | - 100, |
192 | | - null, |
193 | | - null, |
194 | | - null |
195 | | - ); |
196 | | - source.retriever( |
197 | | - new RRFRetrieverBuilder( |
198 | | - Arrays.asList( |
199 | | - new CompoundRetrieverBuilder.RetrieverSource(standard0, null), |
200 | | - new CompoundRetrieverBuilder.RetrieverSource(standard1, null), |
201 | | - new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) |
202 | | - ), |
203 | | - rankWindowSize, |
204 | | - rankConstant |
205 | | - ) |
206 | | - ); |
207 | | - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); |
208 | | - int fDocs_to_fetch = docs_to_fetch; |
209 | | - ElasticsearchAssertions.assertResponse(req, resp -> { |
210 | | - assertNull(resp.pointInTimeId()); |
211 | | - assertNotNull(resp.getHits().getTotalHits()); |
212 | | - assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); |
213 | | - assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); |
214 | | - assertThat(resp.getHits().getHits().length, lessThanOrEqualTo(size)); |
215 | | - for (int k = 0; k < Math.min(size, resp.getHits().getHits().length); k++) { |
216 | | - assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + fDocs_to_fetch))); |
217 | | - } |
218 | | - }); |
| 167 | + assertRRFPagination(source, from_value, size, rankWindowSize, rankConstant, expectedDocIds); |
219 | 168 | } |
220 | 169 | } |
| 170 | + |
| 171 | + // test with `from` as the default (-1) |
| 172 | + for (int i = 0; i < randomIntBetween(5, 20); i++) { |
| 173 | + int size = randomIntBetween(1, totalDocs); |
| 174 | + SearchSourceBuilder source = new SearchSourceBuilder(); |
| 175 | + source.size(size); |
| 176 | + assertRRFPagination(source, source.from(), size, rankWindowSize, rankConstant, expectedDocIds); |
| 177 | + } |
| 178 | + |
| 179 | + // and finally test with from = default, and size > {total docs} to be sure |
| 180 | + SearchSourceBuilder source = new SearchSourceBuilder(); |
| 181 | + source.size(totalDocs + 2); |
| 182 | + assertRRFPagination(source, source.from(), totalDocs, rankWindowSize, rankConstant, expectedDocIds); |
| 183 | + } |
| 184 | + |
| 185 | + private void assertRRFPagination( |
| 186 | + SearchSourceBuilder source, |
| 187 | + int from, |
| 188 | + int maxExpectedSize, |
| 189 | + int rankWindowSize, |
| 190 | + int rankConstant, |
| 191 | + List<String> expectedDocIds |
| 192 | + ) { |
| 193 | + // this one retrieves docs 1, 2, 4, 6, and 7 |
| 194 | + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( |
| 195 | + QueryBuilders.boolQuery() |
| 196 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) |
| 197 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) |
| 198 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) |
| 199 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) |
| 200 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) |
| 201 | + ); |
| 202 | + // this one retrieves docs 2 and 6 due to prefilter |
| 203 | + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( |
| 204 | + QueryBuilders.boolQuery() |
| 205 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) |
| 206 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) |
| 207 | + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) |
| 208 | + ); |
| 209 | + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); |
| 210 | + // this one retrieves docs 2, 3, 6, and 7 |
| 211 | + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( |
| 212 | + VECTOR_FIELD, |
| 213 | + new float[] { 2.0f }, |
| 214 | + null, |
| 215 | + 10, |
| 216 | + 100, |
| 217 | + null, |
| 218 | + null, |
| 219 | + null |
| 220 | + ); |
| 221 | + source.retriever( |
| 222 | + new RRFRetrieverBuilder( |
| 223 | + Arrays.asList( |
| 224 | + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), |
| 225 | + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), |
| 226 | + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) |
| 227 | + ), |
| 228 | + rankWindowSize, |
| 229 | + rankConstant |
| 230 | + ) |
| 231 | + ); |
| 232 | + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); |
| 233 | + |
| 234 | + int innerFrom = Math.max(from, 0); |
| 235 | + ElasticsearchAssertions.assertResponse(req, resp -> { |
| 236 | + assertNull(resp.pointInTimeId()); |
| 237 | + assertNotNull(resp.getHits().getTotalHits()); |
| 238 | + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); |
| 239 | + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); |
| 240 | + |
| 241 | + int expectedSize = innerFrom + maxExpectedSize > 6 ? 6 - innerFrom : maxExpectedSize; |
| 242 | + assertThat(resp.getHits().getHits().length, equalTo(expectedSize)); |
| 243 | + |
| 244 | + for (int k = 0; k < expectedSize; k++) { |
| 245 | + assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + innerFrom))); |
| 246 | + } |
| 247 | + }); |
221 | 248 | } |
222 | 249 |
|
223 | 250 | public void testRRFWithAggs() { |
|
0 commit comments