Skip to content

Commit cfe449a

Browse files
committed
Fix random sampling with scores and p=1.0
1 parent 81cdffd commit cfe449a

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.search.BooleanQuery;
1414
import org.apache.lucene.search.CollectionTerminatedException;
1515
import org.apache.lucene.search.DocIdSetIterator;
16+
import org.apache.lucene.search.Query;
1617
import org.apache.lucene.search.ScoreMode;
1718
import org.apache.lucene.search.Scorer;
1819
import org.apache.lucene.search.Weight;
@@ -24,6 +25,7 @@
2425
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
2526
import org.elasticsearch.search.aggregations.InternalAggregation;
2627
import org.elasticsearch.search.aggregations.LeafBucketCollector;
28+
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
2729
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
2830
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
2931
import org.elasticsearch.search.aggregations.support.AggregationContext;
@@ -70,16 +72,16 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
7072
*/
7173
private Weight getWeight() throws IOException {
7274
if (weight == null) {
73-
RandomSamplingQuery query = new RandomSamplingQuery(
74-
probability,
75-
seed,
76-
shardSeed == null ? context.shardRandomSeed() : shardSeed
77-
);
7875
ScoreMode scoreMode = scoreMode();
79-
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
80-
.add(context.query(), scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER)
81-
.build();
82-
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), scoreMode, 1f);
76+
BooleanQuery.Builder fullQuery = new BooleanQuery.Builder().add(
77+
context.query(),
78+
scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER
79+
);
80+
if (probability < 1.0) {
81+
Query sampleQuery = new RandomSamplingQuery(probability, seed, shardSeed == null ? context.shardRandomSeed() : shardSeed);
82+
fullQuery.add(sampleQuery, BooleanClause.Occur.FILTER);
83+
}
84+
weight = context.searcher().createWeight(context.searcher().rewrite(fullQuery.build()), scoreMode, 1f);
8385
}
8486
return weight;
8587
}
@@ -125,23 +127,25 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
125127
if (sub.isNoop()) {
126128
return LeafBucketCollector.NO_OP_COLLECTOR;
127129
}
130+
131+
Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
132+
// This means there are no docs to iterate, possibly due to the fields not existing
133+
if (scorer == null) {
134+
return LeafBucketCollector.NO_OP_COLLECTOR;
135+
}
136+
sub.setScorer(scorer);
137+
128138
// No sampling is being done, collect all docs
139+
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
129140
if (probability >= 1.0) {
130141
grow(1);
131-
return new LeafBucketCollector() {
142+
return new LeafBucketCollectorBase(sub, null) {
132143
@Override
133144
public void collect(int doc, long owningBucketOrd) throws IOException {
134145
collectExistingBucket(sub, doc, 0);
135146
}
136147
};
137148
}
138-
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
139-
Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
140-
// This means there are no docs to iterate, possibly due to the fields not existing
141-
if (scorer == null) {
142-
return LeafBucketCollector.NO_OP_COLLECTOR;
143-
}
144-
sub.setScorer(scorer);
145149

146150
final DocIdSetIterator docIt = scorer.iterator();
147151
final Bits liveDocs = aggCtx.getLeafReaderContext().reader().getLiveDocs();
@@ -162,5 +166,4 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
162166
// Since we have done our own collection, there is nothing for the leaf collector to do
163167
return LeafBucketCollector.NO_OP_COLLECTOR;
164168
}
165-
166169
}

0 commit comments

Comments
 (0)