Skip to content

Commit 4cf9b0c

Browse files
authored
Enforce no-op setWeight on collectors that filter docs out (#94886)
MinimumScoreCollector and FilteredCollector filter documents out as part of their collection. They have an inner collector to delegate to, but they should never propagate the Weight to it, otherwise the total hit count may not reflect the filtering. This commit clarifies this through an empty final setWeight method on both collectors and additional javadocs.
1 parent b546d70 commit 4cf9b0c

File tree

4 files changed

+230
-0
lines changed

4 files changed

+230
-0
lines changed

server/src/main/java/org/elasticsearch/common/lucene/MinimumScoreCollector.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@
1515
import org.apache.lucene.search.ScoreCachingWrappingScorer;
1616
import org.apache.lucene.search.ScoreMode;
1717
import org.apache.lucene.search.SimpleCollector;
18+
import org.apache.lucene.search.Weight;
1819

1920
import java.io.IOException;
2021

22+
/**
23+
* Collector that wraps another collector and collects only documents that have a score that's greater than or equal to the
24+
* provided minimum score. Given that this collector filters documents out, it must not propagate the {@link Weight} to its
25+
* inner collector, as that may lead to exposing total hit count that does not reflect the filtering.
26+
*/
2127
public class MinimumScoreCollector extends SimpleCollector {
2228

2329
private final Collector collector;
@@ -31,6 +37,12 @@ public MinimumScoreCollector(Collector collector, float minimumScore) {
3137
this.minimumScore = minimumScore;
3238
}
3339

40+
/**
 * Deliberately ignores the given {@link Weight}: the body is empty, so nothing is forwarded to the wrapped collector.
 * The method is {@code final}, so subclasses cannot override it to start propagating the weight.
 *
 * @param weight the weight offered to this collector; unused
 */
@Override
41+
public final void setWeight(Weight weight) {
42+
// no-op: this collector filters documents out hence it must not propagate the weight to its inner collector,
43+
// otherwise the total hit count may not reflect the filtering
44+
}
45+
3446
@Override
3547
public void setScorer(Scorable scorer) throws IOException {
3648
if ((scorer instanceof ScoreCachingWrappingScorer) == false) {

server/src/main/java/org/elasticsearch/common/lucene/search/FilteredCollector.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
import java.io.IOException;
2121

22+
/**
23+
* Collector that wraps another collector and collects only documents that match the provided filter.
24+
* Given that this collector filters documents out, it must not propagate the {@link Weight} to its
25+
* inner collector, as that may lead to exposing total hit count that does not reflect the filtering.
26+
*/
2227
public class FilteredCollector implements Collector {
2328

2429
private final Collector collector;
@@ -29,6 +34,12 @@ public FilteredCollector(Collector collector, Weight filter) {
2934
this.filter = filter;
3035
}
3136

37+
/**
 * Deliberately ignores the given {@link Weight}: the body is empty, so nothing is forwarded to the wrapped collector.
 * The method is {@code final}, so subclasses cannot override it to start propagating the weight.
 *
 * @param weight the weight offered to this collector; unused
 */
@Override
38+
public final void setWeight(Weight weight) {
39+
// no-op: this collector filters documents out hence it must not propagate the weight to its inner collector,
40+
// otherwise the total hit count may not reflect the filtering
41+
}
42+
3243
@Override
3344
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
3445
final ScorerSupplier filterScorerSupplier = filter.scorerSupplier(context);
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.common.lucene;
10+
11+
import org.apache.lucene.document.Document;
12+
import org.apache.lucene.document.Field;
13+
import org.apache.lucene.document.StringField;
14+
import org.apache.lucene.index.IndexReader;
15+
import org.apache.lucene.index.Term;
16+
import org.apache.lucene.search.BooleanClause;
17+
import org.apache.lucene.search.BooleanQuery;
18+
import org.apache.lucene.search.BoostQuery;
19+
import org.apache.lucene.search.IndexSearcher;
20+
import org.apache.lucene.search.MatchAllDocsQuery;
21+
import org.apache.lucene.search.TermQuery;
22+
import org.apache.lucene.search.TopDocs;
23+
import org.apache.lucene.search.TopScoreDocCollector;
24+
import org.apache.lucene.search.TotalHitCountCollector;
25+
import org.apache.lucene.store.Directory;
26+
import org.apache.lucene.tests.index.RandomIndexWriter;
27+
import org.elasticsearch.core.IOUtils;
28+
import org.elasticsearch.test.ESTestCase;
29+
30+
import java.io.IOException;
31+
32+
public class MinimumScoreCollectorTests extends ESTestCase {
33+
34+
private Directory directory;
35+
private IndexReader reader;
36+
private IndexSearcher searcher;
37+
private int numDocs;
38+
39+
/**
 * Builds the index shared by the tests in this class: a random number of documents (between 10 and 100),
 * each carrying {@code field1=value}; only the first document added (i == 0) also carries {@code field2=value}.
 * A reader and a searcher are then opened on that index.
 */
@Override
40+
public void setUp() throws Exception {
41+
super.setUp();
42+
directory = newDirectory();
43+
RandomIndexWriter writer = new RandomIndexWriter(random(), directory, newIndexWriterConfig());
44+
numDocs = randomIntBetween(10, 100);
45+
for (int i = 0; i < numDocs; i++) {
46+
Document doc = new Document();
47+
doc.add(new StringField("field1", "value", Field.Store.NO));
48+
// only the first document gets the second field
49+
if (i == 0) {
50+
doc.add(new StringField("field2", "value", Field.Store.NO));
51+
}
52+
writer.addDocument(doc);
53+
}
54+
writer.flush();
55+
// the reader is obtained from the writer before the writer is closed
56+
reader = writer.getReader();
57+
searcher = newSearcher(reader);
58+
writer.close();
59+
}
58+
59+
/**
 * Runs the superclass teardown, then closes the reader and the directory that were opened in {@code setUp}.
 */
@Override
60+
public void tearDown() throws Exception {
61+
super.tearDown();
62+
IOUtils.close(reader, directory);
63+
}
64+
65+
/**
 * Checks that {@link MinimumScoreCollector} only lets through documents whose score reaches the configured minimum,
 * by looking at the total hit count reported by the wrapped {@link TopScoreDocCollector}.
 * The query requires {@code field1=value} and optionally matches {@code field2=value} with a boost of 200.
 */
public void testMinScoreFiltering() throws IOException {
66+
float maxScore;
67+
float thresholdScore;
68+
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(new TermQuery(new Term("field1", "value")), BooleanClause.Occur.MUST)
69+
.add(new BoostQuery(new TermQuery(new Term("field2", "value")), 200f), BooleanClause.Occur.SHOULD)
70+
.build();
71+
// baseline without score filtering: all documents are expected to match; record the two highest scores
72+
{
73+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(2, 100);
74+
searcher.search(booleanQuery, topScoreDocCollector);
75+
TopDocs topDocs = topScoreDocCollector.topDocs();
76+
assertEquals(numDocs, topDocs.totalHits.value);
77+
maxScore = topDocs.scoreDocs[0].score;
78+
thresholdScore = topDocs.scoreDocs[1].score;
79+
}
// minimum score equal to the best score: exactly one hit is expected
// (presumably the single document that also matches the boosted field2 clause)
80+
{
81+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(1, 100);
82+
searcher.search(booleanQuery, new MinimumScoreCollector(topScoreDocCollector, maxScore));
83+
assertEquals(1, topScoreDocCollector.topDocs().totalHits.value);
84+
}
// minimum score equal to the second-best score: every document is expected to pass
85+
{
86+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(1, 100);
87+
searcher.search(booleanQuery, new MinimumScoreCollector(topScoreDocCollector, thresholdScore));
88+
assertEquals(numDocs, topScoreDocCollector.topDocs().totalHits.value);
89+
}
// minimum score above the best score: no document is expected to pass
90+
{
91+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(1, 100);
92+
searcher.search(booleanQuery, new MinimumScoreCollector(topScoreDocCollector, maxScore + 100f));
93+
assertEquals(0, topScoreDocCollector.topDocs().totalHits.value);
94+
}
95+
96+
/**
 * Checks that a {@link TotalHitCountCollector} wrapped in a {@link MinimumScoreCollector} reports the filtered count
 * rather than the unfiltered one.
 * NOTE(review): presumably the inner collector could otherwise derive its count from the Weight instead of from
 * collected documents; confirm against the Lucene version in use.
 */
public void testWeightIsNotPropagated() throws IOException {
97+
// unwrapped collector: the count is expected to equal reader.maxDoc()
{
98+
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
99+
searcher.search(new MatchAllDocsQuery(), totalHitCountCollector);
100+
assertEquals(reader.maxDoc(), totalHitCountCollector.getTotalHits());
101+
}
// wrapped with a minimum score of 100: no hits are expected to reach the inner collector
102+
{
103+
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
104+
searcher.search(new MatchAllDocsQuery(), new MinimumScoreCollector(totalHitCountCollector, 100f));
105+
assertEquals(0, totalHitCountCollector.getTotalHits());
106+
}
107+
}
108+
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.common.lucene.search;
10+
11+
import org.apache.lucene.document.Document;
12+
import org.apache.lucene.document.Field;
13+
import org.apache.lucene.document.StringField;
14+
import org.apache.lucene.index.IndexReader;
15+
import org.apache.lucene.index.Term;
16+
import org.apache.lucene.search.IndexSearcher;
17+
import org.apache.lucene.search.MatchAllDocsQuery;
18+
import org.apache.lucene.search.ScoreMode;
19+
import org.apache.lucene.search.TermQuery;
20+
import org.apache.lucene.search.TopScoreDocCollector;
21+
import org.apache.lucene.search.TotalHitCountCollector;
22+
import org.apache.lucene.search.Weight;
23+
import org.apache.lucene.store.Directory;
24+
import org.apache.lucene.tests.index.RandomIndexWriter;
25+
import org.elasticsearch.core.IOUtils;
26+
import org.elasticsearch.test.ESTestCase;
27+
28+
import java.io.IOException;
29+
30+
public class FilteredCollectorTests extends ESTestCase {
31+
32+
private Directory directory;
33+
private IndexReader reader;
34+
private IndexSearcher searcher;
35+
private int numDocs;
36+
37+
/**
 * Builds the index shared by the tests in this class: a random number of documents (between 10 and 100),
 * each carrying {@code field1=value}; only the first document added (i == 0) also carries {@code field2=value}.
 * A reader and a searcher are then opened on that index.
 */
@Override
38+
public void setUp() throws Exception {
39+
super.setUp();
40+
directory = newDirectory();
41+
RandomIndexWriter writer = new RandomIndexWriter(random(), directory, newIndexWriterConfig());
42+
numDocs = randomIntBetween(10, 100);
43+
for (int i = 0; i < numDocs; i++) {
44+
Document doc = new Document();
45+
doc.add(new StringField("field1", "value", Field.Store.NO));
46+
// only the first document gets the second field
47+
if (i == 0) {
48+
doc.add(new StringField("field2", "value", Field.Store.NO));
49+
}
50+
writer.addDocument(doc);
51+
}
52+
writer.flush();
53+
// the reader is obtained from the writer before the writer is closed
54+
reader = writer.getReader();
55+
searcher = newSearcher(reader);
56+
writer.close();
57+
}
56+
57+
/**
 * Runs the superclass teardown, then closes the reader and the directory that were opened in {@code setUp}.
 */
@Override
58+
public void tearDown() throws Exception {
59+
super.tearDown();
60+
IOUtils.close(reader, directory);
61+
}
62+
63+
/**
 * Checks that {@link FilteredCollector} only passes documents matching the filter {@link Weight} to the wrapped
 * {@link TopScoreDocCollector}, by comparing the reported total hit counts.
 */
public void testFiltering() throws IOException {
64+
// baseline without a filter: all documents are expected to be counted
{
65+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(1, 100);
66+
searcher.search(new MatchAllDocsQuery(), topScoreDocCollector);
67+
assertEquals(numDocs, topScoreDocCollector.topDocs().totalHits.value);
68+
}
// filter on field2, which only one document carries: exactly one hit is expected
69+
{
70+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(1, 100);
71+
TermQuery termQuery = new TermQuery(new Term("field2", "value"));
72+
Weight filterWeight = termQuery.createWeight(searcher, ScoreMode.TOP_DOCS, 1f);
73+
searcher.search(new MatchAllDocsQuery(), new FilteredCollector(topScoreDocCollector, filterWeight));
74+
assertEquals(1, topScoreDocCollector.topDocs().totalHits.value);
75+
}
// filter on field1, which every document carries: all documents are expected to pass
76+
{
77+
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(1, 100);
78+
TermQuery termQuery = new TermQuery(new Term("field1", "value"));
79+
Weight filterWeight = termQuery.createWeight(searcher, ScoreMode.TOP_DOCS, 1f);
80+
searcher.search(new MatchAllDocsQuery(), new FilteredCollector(topScoreDocCollector, filterWeight));
81+
assertEquals(numDocs, topScoreDocCollector.topDocs().totalHits.value);
82+
}
83+
}
84+
85+
/**
 * Checks that a {@link TotalHitCountCollector} wrapped in a {@link FilteredCollector} reports the filtered count
 * rather than the unfiltered one.
 * NOTE(review): presumably the inner collector could otherwise derive its count from the Weight instead of from
 * collected documents; confirm against the Lucene version in use.
 */
public void testWeightIsNotPropagated() throws IOException {
86+
// unwrapped collector: the count is expected to equal reader.maxDoc()
{
87+
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
88+
searcher.search(new MatchAllDocsQuery(), totalHitCountCollector);
89+
assertEquals(reader.maxDoc(), totalHitCountCollector.getTotalHits());
90+
}
// wrapped with a filter on field2, which only one document carries: exactly one hit is expected
91+
{
92+
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
93+
TermQuery termQuery = new TermQuery(new Term("field2", "value"));
94+
Weight filterWeight = termQuery.createWeight(searcher, ScoreMode.TOP_DOCS, 1f);
95+
searcher.search(new MatchAllDocsQuery(), new FilteredCollector(totalHitCountCollector, filterWeight));
96+
assertEquals(1, totalHitCountCollector.getTotalHits());
97+
}
98+
}
99+
}

0 commit comments

Comments
 (0)