Skip to content

Commit b27e007

Browse files
committed
Correct score mode in random sampler weight
1 parent 30ad97b commit b27e007

File tree

3 files changed

+31
-46
lines changed

3 files changed

+31
-46
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {
4040

4141
protected final String name;
4242
protected final Aggregator parent;
43-
private final AggregationContext context;
43+
protected final AggregationContext context;
4444
private final Map<String, Object> metadata;
4545

4646
protected final Aggregator[] subAggregators;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12+
import org.apache.lucene.search.BooleanClause;
13+
import org.apache.lucene.search.BooleanQuery;
1214
import org.apache.lucene.search.CollectionTerminatedException;
1315
import org.apache.lucene.search.DocIdSetIterator;
16+
import org.apache.lucene.search.ScoreMode;
1417
import org.apache.lucene.search.Scorer;
1518
import org.apache.lucene.search.Weight;
1619
import org.apache.lucene.util.Bits;
17-
import org.elasticsearch.common.CheckedSupplier;
1820
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
1921
import org.elasticsearch.search.aggregations.Aggregator;
2022
import org.elasticsearch.search.aggregations.AggregatorFactories;
@@ -33,14 +35,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
3335
private final int seed;
3436
private final Integer shardSeed;
3537
private final double probability;
36-
private final CheckedSupplier<Weight, IOException> weightSupplier;
38+
private Weight weight;
3739

3840
RandomSamplerAggregator(
3941
String name,
4042
int seed,
4143
Integer shardSeed,
4244
double probability,
43-
CheckedSupplier<Weight, IOException> weightSupplier,
4445
AggregatorFactories factories,
4546
AggregationContext context,
4647
Aggregator parent,
@@ -55,10 +56,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
5556
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
5657
);
5758
}
58-
this.weightSupplier = weightSupplier;
5959
this.shardSeed = shardSeed;
6060
}
6161

62+
/**
63+
* This creates the query weight which will be used in the aggregator.
64+
*
65+
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
66+
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
67+
* @return weight to be used, is cached for additional usages
68+
* @throws IOException when building the weight or queries fails;
69+
*/
70+
private Weight getWeight() throws IOException {
71+
if (weight == null) {
72+
RandomSamplingQuery query = new RandomSamplingQuery(
73+
probability,
74+
seed,
75+
shardSeed == null ? context.shardRandomSeed() : shardSeed
76+
);
77+
ScoreMode scoreMode = scoreMode();
78+
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
79+
.add(context.query(), scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER)
80+
.build();
81+
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), scoreMode, 1f);
82+
}
83+
return weight;
84+
}
85+
6286
@Override
6387
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
6488
return buildAggregationsForSingleBucket(
@@ -111,7 +135,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
111135
};
112136
}
113137
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
114-
Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
138+
Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
115139
// This means there are no docs to iterate, possibly due to the fields not existing
116140
if (scorer == null) {
117141
return LeafBucketCollector.NO_OP_COLLECTOR;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12-
import org.apache.lucene.search.BooleanClause;
13-
import org.apache.lucene.search.BooleanQuery;
14-
import org.apache.lucene.search.ScoreMode;
15-
import org.apache.lucene.search.Weight;
1612
import org.elasticsearch.search.aggregations.Aggregator;
1713
import org.elasticsearch.search.aggregations.AggregatorFactories;
1814
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
3026
private final Integer shardSeed;
3127
private final double probability;
3228
private final SamplingContext samplingContext;
33-
private Weight weight;
3429

3530
RandomSamplerAggregatorFactory(
3631
String name,
@@ -57,40 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
5752
@Override
5853
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
5954
throws IOException {
60-
return new RandomSamplerAggregator(
61-
name,
62-
seed,
63-
shardSeed,
64-
probability,
65-
this::getWeight,
66-
factories,
67-
context,
68-
parent,
69-
cardinality,
70-
metadata
71-
);
72-
}
73-
74-
/**
75-
* This creates the query weight which will be used in the aggregator.
76-
*
77-
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
78-
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
79-
* @return weight to be used, is cached for additional usages
80-
* @throws IOException when building the weight or queries fails;
81-
*/
82-
private Weight getWeight() throws IOException {
83-
if (weight == null) {
84-
RandomSamplingQuery query = new RandomSamplingQuery(
85-
probability,
86-
seed,
87-
shardSeed == null ? context.shardRandomSeed() : shardSeed
88-
);
89-
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
90-
.add(context.query(), BooleanClause.Occur.MUST)
91-
.build();
92-
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE, 1f);
93-
}
94-
return weight;
55+
return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
9556
}
9657
}

0 commit comments

Comments
 (0)