Skip to content

Commit c18034b

Browse files
committed
Minscore is made sure it works at the lower level
1 parent cd2fe03 commit c18034b

File tree

2 files changed

+91
-16
lines changed

2 files changed

+91
-16
lines changed

server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,12 @@ public RankDocsQuery(
274274
this.minScore = minScore;
275275
}
276276

277-
private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs) {
277+
private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs, float minScore) {
278278
this.docs = docs;
279279
this.topQuery = topQuery;
280280
this.tailQuery = tailQuery;
281281
this.onlyRankDocs = onlyRankDocs;
282-
this.minScore = DEFAULT_MIN_SCORE;
282+
this.minScore = minScore;
283283
}
284284

285285
private static int binarySearch(RankDoc[] docs, int fromIndex, int toIndex, int key) {
@@ -312,7 +312,11 @@ public RankDoc[] rankDocs() {
312312
@Override
313313
public Query rewrite(IndexSearcher searcher) throws IOException {
314314
if (tailQuery == null) {
315-
return topQuery;
315+
var topRewrite = topQuery.rewrite(searcher);
316+
if (topRewrite != topQuery) {
317+
return new RankDocsQuery(this.docs, topRewrite, null, this.onlyRankDocs, this.minScore);
318+
}
319+
return this;
316320
}
317321
boolean hasChanged = false;
318322
var topRewrite = topQuery.rewrite(searcher);
@@ -323,22 +327,33 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
323327
if (tailRewrite != tailQuery) {
324328
hasChanged = true;
325329
}
326-
return hasChanged ? new RankDocsQuery(docs, topRewrite, tailRewrite, onlyRankDocs) : this;
330+
return hasChanged ? new RankDocsQuery(this.docs, topRewrite, tailRewrite, this.onlyRankDocs, this.minScore) : this;
327331
}
328332

329333
@Override
330334
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
331-
if (tailQuery == null) {
332-
throw new IllegalArgumentException("[tailQuery] should not be null; maybe missing a rewrite?");
335+
Query combinedQuery;
336+
if (onlyRankDocs) {
337+
combinedQuery = topQuery;
338+
} else {
339+
if (tailQuery == null) {
340+
combinedQuery = topQuery;
341+
} else {
342+
var combined = new BooleanQuery.Builder().add(topQuery, BooleanClause.Occur.SHOULD)
343+
.add(tailQuery, BooleanClause.Occur.FILTER)
344+
.build();
345+
combinedQuery = combined;
346+
}
333347
}
334-
var combined = new BooleanQuery.Builder().add(topQuery, onlyRankDocs ? BooleanClause.Occur.MUST : BooleanClause.Occur.SHOULD)
335-
.add(tailQuery, BooleanClause.Occur.FILTER)
336-
.build();
348+
337349
var topWeight = topQuery.createWeight(searcher, scoreMode, boost);
338-
var combinedWeight = searcher.rewrite(combined).createWeight(searcher, scoreMode, boost);
350+
var combinedWeight = searcher.rewrite(combinedQuery).createWeight(searcher, scoreMode, boost);
339351
return new Weight(this) {
340352
@Override
341353
public int count(LeafReaderContext context) throws IOException {
354+
if (onlyRankDocs) {
355+
return topWeight.count(context);
356+
}
342357
return combinedWeight.count(context);
343358
}
344359

@@ -359,22 +374,23 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {
359374

360375
@Override
361376
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
362-
ScorerSupplier supplier = combinedWeight.scorerSupplier(context);
363-
if (minScore != DEFAULT_MIN_SCORE) {
377+
ScorerSupplier baseSupplier = onlyRankDocs ? topWeight.scorerSupplier(context) : combinedWeight.scorerSupplier(context);
378+
379+
if (minScore != DEFAULT_MIN_SCORE && baseSupplier != null) {
364380
return new ScorerSupplier() {
365381
@Override
366382
public Scorer get(long leadCost) throws IOException {
367-
Scorer scorer = supplier.get(leadCost);
368-
return new MinScoreScorer(scorer, minScore);
383+
Scorer scorer = baseSupplier.get(leadCost);
384+
return scorer == null ? null : new MinScoreScorer(scorer, minScore);
369385
}
370386

371387
@Override
372388
public long cost() {
373-
return supplier.cost();
389+
return baseSupplier.cost();
374390
}
375391
};
376392
}
377-
return supplier;
393+
return baseSupplier;
378394
}
379395
};
380396
}

server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.apache.lucene.search.IndexSearcher;
2020
import org.apache.lucene.search.Query;
2121
import org.apache.lucene.search.ScoreDoc;
22+
import org.apache.lucene.search.TopDocs;
2223
import org.apache.lucene.search.TopScoreDocCollectorManager;
2324
import org.apache.lucene.store.Directory;
2425
import org.apache.lucene.tests.index.RandomIndexWriter;
@@ -278,4 +279,62 @@ public void testMinScoreSerializationAndParsing() throws IOException {
278279

279280
assertArrayEquals(rankDocs, parsedBuilder.rankDocs());
280281
}
282+
283+
public void testRankDocsQueryMinScoreFiltering() throws IOException {
284+
try (Directory directory = newDirectory()) {
285+
try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
286+
Document doc0 = new Document();
287+
iw.addDocument(doc0);
288+
Document doc1 = new Document();
289+
iw.addDocument(doc1);
290+
Document doc2 = new Document();
291+
iw.addDocument(doc2);
292+
Document doc3 = new Document();
293+
iw.addDocument(doc3);
294+
}
295+
296+
float minScore = 1.5f;
297+
RankDoc[] rankDocs = new RankDoc[] {
298+
new RankDoc(0, 2.0f, 0),
299+
new RankDoc(1, 1.0f, 0),
300+
new RankDoc(2, 1.6f, 0)
301+
};
302+
Arrays.sort(rankDocs);
303+
for (int i = 0; i < rankDocs.length; i++) {
304+
rankDocs[i].rank = i;
305+
}
306+
RankDocsQueryBuilder builder = new RankDocsQueryBuilder(rankDocs, null, true, minScore);
307+
308+
try (IndexReader reader = DirectoryReader.open(directory)) {
309+
SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
310+
Query query = builder.doToQuery(context);
311+
assertTrue("Query should be RankDocsQuery", query instanceof RankDocsQuery);
312+
313+
IndexSearcher searcher = newSearcher(reader);
314+
TopDocs topDocs = searcher.search(query, 10);
315+
316+
long expectedTotalHits = 2;
317+
long expectedFilteredHits = 2;
318+
assertEquals("Total hits should match filtered count", expectedTotalHits, topDocs.totalHits.value());
319+
assertEquals("Number of score docs should match filtered count", expectedFilteredHits, topDocs.scoreDocs.length);
320+
321+
boolean foundDoc0 = false;
322+
boolean foundDoc2 = false;
323+
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
324+
assertTrue("Score should be >= minScore", scoreDoc.score >= minScore);
325+
if (scoreDoc.doc == 0) {
326+
assertEquals("Doc 0 score mismatch", 2.0f, scoreDoc.score, 0f);
327+
foundDoc0 = true;
328+
} else if (scoreDoc.doc == 2) {
329+
assertEquals("Doc 2 score mismatch", 1.6f, scoreDoc.score, 0f);
330+
foundDoc2 = true;
331+
} else {
332+
fail("Unexpected document ID returned: " + scoreDoc.doc);
333+
}
334+
}
335+
assertTrue("Document 0 should have been found", foundDoc0);
336+
assertTrue("Document 2 should have been found", foundDoc2);
337+
}
338+
}
339+
}
281340
}

0 commit comments

Comments
 (0)