Skip to content

Commit 81cdffd

Browse files
committed
Correct score mode in random sampler weight
1 parent be60ad8 commit 81cdffd

File tree

3 files changed

+31
-46
lines changed

3 files changed

+31
-46
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {
4040

4141
protected final String name;
4242
protected final Aggregator parent;
43-
private final AggregationContext context;
43+
protected final AggregationContext context;
4444
private final Map<String, Object> metadata;
4545

4646
protected final Aggregator[] subAggregators;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12+
import org.apache.lucene.search.BooleanClause;
13+
import org.apache.lucene.search.BooleanQuery;
1214
import org.apache.lucene.search.CollectionTerminatedException;
1315
import org.apache.lucene.search.DocIdSetIterator;
16+
import org.apache.lucene.search.ScoreMode;
1417
import org.apache.lucene.search.Scorer;
1518
import org.apache.lucene.search.Weight;
1619
import org.apache.lucene.util.Bits;
17-
import org.elasticsearch.common.CheckedSupplier;
1820
import org.elasticsearch.common.util.LongArray;
1921
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
2022
import org.elasticsearch.search.aggregations.Aggregator;
@@ -34,14 +36,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
3436
private final int seed;
3537
private final Integer shardSeed;
3638
private final double probability;
37-
private final CheckedSupplier<Weight, IOException> weightSupplier;
39+
private Weight weight;
3840

3941
RandomSamplerAggregator(
4042
String name,
4143
int seed,
4244
Integer shardSeed,
4345
double probability,
44-
CheckedSupplier<Weight, IOException> weightSupplier,
4546
AggregatorFactories factories,
4647
AggregationContext context,
4748
Aggregator parent,
@@ -56,10 +57,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
5657
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
5758
);
5859
}
59-
this.weightSupplier = weightSupplier;
6060
this.shardSeed = shardSeed;
6161
}
6262

63+
/**
64+
* This creates the query weight which will be used in the aggregator.
65+
*
66+
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
67+
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
68+
* @return weight to be used, is cached for additional usages
69+
* @throws IOException when building the weight or queries fails;
70+
*/
71+
private Weight getWeight() throws IOException {
72+
if (weight == null) {
73+
RandomSamplingQuery query = new RandomSamplingQuery(
74+
probability,
75+
seed,
76+
shardSeed == null ? context.shardRandomSeed() : shardSeed
77+
);
78+
ScoreMode scoreMode = scoreMode();
79+
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
80+
.add(context.query(), scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER)
81+
.build();
82+
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), scoreMode, 1f);
83+
}
84+
return weight;
85+
}
86+
6387
@Override
6488
public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException {
6589
return buildAggregationsForSingleBucket(
@@ -112,7 +136,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
112136
};
113137
}
114138
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
115-
Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
139+
Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
116140
// This means there are no docs to iterate, possibly due to the fields not existing
117141
if (scorer == null) {
118142
return LeafBucketCollector.NO_OP_COLLECTOR;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12-
import org.apache.lucene.search.BooleanClause;
13-
import org.apache.lucene.search.BooleanQuery;
14-
import org.apache.lucene.search.ScoreMode;
15-
import org.apache.lucene.search.Weight;
1612
import org.elasticsearch.search.aggregations.Aggregator;
1713
import org.elasticsearch.search.aggregations.AggregatorFactories;
1814
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
3026
private final Integer shardSeed;
3127
private final double probability;
3228
private final SamplingContext samplingContext;
33-
private Weight weight;
3429

3530
RandomSamplerAggregatorFactory(
3631
String name,
@@ -57,40 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
5752
@Override
5853
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
5954
throws IOException {
60-
return new RandomSamplerAggregator(
61-
name,
62-
seed,
63-
shardSeed,
64-
probability,
65-
this::getWeight,
66-
factories,
67-
context,
68-
parent,
69-
cardinality,
70-
metadata
71-
);
72-
}
73-
74-
/**
75-
* This creates the query weight which will be used in the aggregator.
76-
*
77-
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
78-
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
79-
* @return weight to be used, is cached for additional usages
80-
* @throws IOException when building the weight or queries fails;
81-
*/
82-
private Weight getWeight() throws IOException {
83-
if (weight == null) {
84-
RandomSamplingQuery query = new RandomSamplingQuery(
85-
probability,
86-
seed,
87-
shardSeed == null ? context.shardRandomSeed() : shardSeed
88-
);
89-
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
90-
.add(context.query(), BooleanClause.Occur.MUST)
91-
.build();
92-
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE, 1f);
93-
}
94-
return weight;
55+
return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
9556
}
9657
}

0 commit comments

Comments
 (0)