Skip to content

Commit 5e6303c

Browse files
authored
Propagate scoring function through random sampler (#116957) (#117162)
* Propagate scoring function through random sampler. * Update docs/changelog/116957.yaml * Correct score mode in random sampler weight * Fix random sampling with scores and p=1.0 * Unit test with scores * YAML test * Add capability
1 parent 1bc60ac commit 5e6303c

File tree

8 files changed

+150
-55
lines changed

8 files changed

+150
-55
lines changed

docs/changelog/116957.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 116957
2+
summary: Propagate scoring function through random sampler
3+
area: Machine Learning
4+
type: bug
5+
issues: [ 110134 ]

modules/aggregations/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ esplugin {
2020

2121
restResources {
2222
restApi {
23-
include '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
23+
include 'capabilities', '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
2424
}
2525
restTests {
2626
// Pulls in all aggregation tests from core AND the forwards v7's core for forwards compatibility

modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/random_sampler.yml

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,66 @@ setup:
142142
}
143143
- match: { aggregations.sampled.mean.value: 1.0 }
144144
---
145+
"Test random_sampler aggregation with scored subagg":
146+
- requires:
147+
capabilities:
148+
- method: POST
149+
path: /_search
150+
capabilities: [ random_sampler_with_scored_subaggs ]
151+
test_runner_features: capabilities
152+
reason: "Support for random sampler with scored subaggs capability required"
153+
- do:
154+
search:
155+
index: data
156+
size: 0
157+
body: >
158+
{
159+
"query": {
160+
"function_score": {
161+
"random_score": {}
162+
}
163+
},
164+
"aggs": {
165+
"sampled": {
166+
"random_sampler": {
167+
"probability": 0.5
168+
},
169+
"aggs": {
170+
"top": {
171+
"top_hits": {}
172+
}
173+
}
174+
}
175+
}
176+
}
177+
- is_true: aggregations.sampled.top.hits
178+
- do:
179+
search:
180+
index: data
181+
size: 0
182+
body: >
183+
{
184+
"query": {
185+
"function_score": {
186+
"random_score": {}
187+
}
188+
},
189+
"aggs": {
190+
"sampled": {
191+
"random_sampler": {
192+
"probability": 1.0
193+
},
194+
"aggs": {
195+
"top": {
196+
"top_hits": {}
197+
}
198+
}
199+
}
200+
}
201+
}
202+
- match: { aggregations.sampled.top.hits.total.value: 6 }
203+
- is_true: aggregations.sampled.top.hits.hits.0._score
204+
---
145205
"Test random_sampler aggregation with poor settings":
146206
- requires:
147207
cluster_features: ["gte_v8.2.0"]

server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ private SearchCapabilities() {}
3939
/** Support multi-dense-vector script field access. */
4040
private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
4141

42+
private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
43+
4244
public static final Set<String> CAPABILITIES;
4345
static {
4446
HashSet<String> capabilities = new HashSet<>();
@@ -47,6 +49,7 @@ private SearchCapabilities() {}
4749
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
4850
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
4951
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
52+
capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS);
5053
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
5154
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
5255
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);

server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {
4040

4141
protected final String name;
4242
protected final Aggregator parent;
43-
private final AggregationContext context;
43+
protected final AggregationContext context;
4444
private final Map<String, Object> metadata;
4545

4646
protected final Aggregator[] subAggregators;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,23 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12+
import org.apache.lucene.search.BooleanClause;
13+
import org.apache.lucene.search.BooleanQuery;
1214
import org.apache.lucene.search.CollectionTerminatedException;
1315
import org.apache.lucene.search.DocIdSetIterator;
16+
import org.apache.lucene.search.Query;
17+
import org.apache.lucene.search.ScoreMode;
1418
import org.apache.lucene.search.Scorer;
1519
import org.apache.lucene.search.Weight;
1620
import org.apache.lucene.util.Bits;
17-
import org.elasticsearch.common.CheckedSupplier;
1821
import org.elasticsearch.common.util.LongArray;
1922
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
2023
import org.elasticsearch.search.aggregations.Aggregator;
2124
import org.elasticsearch.search.aggregations.AggregatorFactories;
2225
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
2326
import org.elasticsearch.search.aggregations.InternalAggregation;
2427
import org.elasticsearch.search.aggregations.LeafBucketCollector;
28+
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
2529
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
2630
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
2731
import org.elasticsearch.search.aggregations.support.AggregationContext;
@@ -34,14 +38,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
3438
private final int seed;
3539
private final Integer shardSeed;
3640
private final double probability;
37-
private final CheckedSupplier<Weight, IOException> weightSupplier;
41+
private Weight weight;
3842

3943
RandomSamplerAggregator(
4044
String name,
4145
int seed,
4246
Integer shardSeed,
4347
double probability,
44-
CheckedSupplier<Weight, IOException> weightSupplier,
4548
AggregatorFactories factories,
4649
AggregationContext context,
4750
Aggregator parent,
@@ -56,10 +59,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
5659
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
5760
);
5861
}
59-
this.weightSupplier = weightSupplier;
6062
this.shardSeed = shardSeed;
6163
}
6264

65+
/**
66+
* This creates the query weight which will be used in the aggregator.
67+
*
68+
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
69+
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
70+
* @return weight to be used, is cached for additional usages
71+
* @throws IOException when building the weight or queries fails;
72+
*/
73+
private Weight getWeight() throws IOException {
74+
if (weight == null) {
75+
ScoreMode scoreMode = scoreMode();
76+
BooleanQuery.Builder fullQuery = new BooleanQuery.Builder().add(
77+
context.query(),
78+
scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER
79+
);
80+
if (probability < 1.0) {
81+
Query sampleQuery = new RandomSamplingQuery(probability, seed, shardSeed == null ? context.shardRandomSeed() : shardSeed);
82+
fullQuery.add(sampleQuery, BooleanClause.Occur.FILTER);
83+
}
84+
weight = context.searcher().createWeight(context.searcher().rewrite(fullQuery.build()), scoreMode, 1f);
85+
}
86+
return weight;
87+
}
88+
6389
@Override
6490
public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException {
6591
return buildAggregationsForSingleBucket(
@@ -101,22 +127,26 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
101127
if (sub.isNoop()) {
102128
return LeafBucketCollector.NO_OP_COLLECTOR;
103129
}
130+
131+
Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
132+
// This means there are no docs to iterate, possibly due to the fields not existing
133+
if (scorer == null) {
134+
return LeafBucketCollector.NO_OP_COLLECTOR;
135+
}
136+
sub.setScorer(scorer);
137+
104138
// No sampling is being done, collect all docs
139+
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
105140
if (probability >= 1.0) {
106141
grow(1);
107-
return new LeafBucketCollector() {
142+
return new LeafBucketCollectorBase(sub, null) {
108143
@Override
109144
public void collect(int doc, long owningBucketOrd) throws IOException {
110145
collectExistingBucket(sub, doc, 0);
111146
}
112147
};
113148
}
114-
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
115-
Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
116-
// This means there are no docs to iterate, possibly due to the fields not existing
117-
if (scorer == null) {
118-
return LeafBucketCollector.NO_OP_COLLECTOR;
119-
}
149+
120150
final DocIdSetIterator docIt = scorer.iterator();
121151
final Bits liveDocs = aggCtx.getLeafReaderContext().reader().getLiveDocs();
122152
try {
@@ -136,5 +166,4 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
136166
// Since we have done our own collection, there is nothing for the leaf collector to do
137167
return LeafBucketCollector.NO_OP_COLLECTOR;
138168
}
139-
140169
}

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12-
import org.apache.lucene.search.BooleanClause;
13-
import org.apache.lucene.search.BooleanQuery;
14-
import org.apache.lucene.search.ScoreMode;
15-
import org.apache.lucene.search.Weight;
1612
import org.elasticsearch.search.aggregations.Aggregator;
1713
import org.elasticsearch.search.aggregations.AggregatorFactories;
1814
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
3026
private final Integer shardSeed;
3127
private final double probability;
3228
private final SamplingContext samplingContext;
33-
private Weight weight;
3429

3530
RandomSamplerAggregatorFactory(
3631
String name,
@@ -57,41 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
5752
@Override
5853
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
5954
throws IOException {
60-
return new RandomSamplerAggregator(
61-
name,
62-
seed,
63-
shardSeed,
64-
probability,
65-
this::getWeight,
66-
factories,
67-
context,
68-
parent,
69-
cardinality,
70-
metadata
71-
);
55+
return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
7256
}
73-
74-
/**
75-
* This creates the query weight which will be used in the aggregator.
76-
*
77-
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
78-
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
79-
* @return weight to be used, is cached for additional usages
80-
* @throws IOException when building the weight or queries fails;
81-
*/
82-
private Weight getWeight() throws IOException {
83-
if (weight == null) {
84-
RandomSamplingQuery query = new RandomSamplingQuery(
85-
probability,
86-
seed,
87-
shardSeed == null ? context.shardRandomSeed() : shardSeed
88-
);
89-
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
90-
.add(context.query(), BooleanClause.Occur.FILTER)
91-
.build();
92-
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
93-
}
94-
return weight;
95-
}
96-
9757
}

server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,29 @@
1111

1212
import org.apache.lucene.document.LongPoint;
1313
import org.apache.lucene.document.SortedNumericDocValuesField;
14+
import org.apache.lucene.index.Term;
15+
import org.apache.lucene.search.BooleanClause;
16+
import org.apache.lucene.search.BooleanQuery;
17+
import org.apache.lucene.search.TermQuery;
1418
import org.apache.lucene.tests.index.RandomIndexWriter;
1519
import org.apache.lucene.util.BytesRef;
1620
import org.elasticsearch.common.Strings;
1721
import org.elasticsearch.index.mapper.KeywordFieldMapper;
1822
import org.elasticsearch.index.query.QueryBuilders;
23+
import org.elasticsearch.search.SearchHit;
1924
import org.elasticsearch.search.aggregations.AggregationBuilders;
2025
import org.elasticsearch.search.aggregations.AggregatorTestCase;
2126
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
2227
import org.elasticsearch.search.aggregations.metrics.Avg;
2328
import org.elasticsearch.search.aggregations.metrics.Max;
2429
import org.elasticsearch.search.aggregations.metrics.Min;
30+
import org.elasticsearch.search.aggregations.metrics.TopHits;
2531
import org.hamcrest.Description;
2632
import org.hamcrest.Matcher;
2733
import org.hamcrest.TypeSafeMatcher;
2834

2935
import java.io.IOException;
36+
import java.util.Arrays;
3037
import java.util.List;
3138
import java.util.concurrent.atomic.AtomicInteger;
3239
import java.util.stream.DoubleStream;
@@ -37,6 +44,8 @@
3744
import static org.hamcrest.Matchers.equalTo;
3845
import static org.hamcrest.Matchers.greaterThan;
3946
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
47+
import static org.hamcrest.Matchers.hasSize;
48+
import static org.hamcrest.Matchers.lessThan;
4049
import static org.hamcrest.Matchers.lessThanOrEqualTo;
4150
import static org.hamcrest.Matchers.not;
4251
import static org.hamcrest.Matchers.notANumber;
@@ -76,6 +85,35 @@ public void testAggregationSampling() throws IOException {
7685
assertThat(avgAvg, closeTo(1.5, 0.5));
7786
}
7887

88+
public void testAggregationSampling_withScores() throws IOException {
89+
long[] counts = new long[5];
90+
AtomicInteger integer = new AtomicInteger();
91+
do {
92+
testCase(RandomSamplerAggregatorTests::writeTestDocs, (InternalRandomSampler result) -> {
93+
counts[integer.get()] = result.getDocCount();
94+
if (result.getDocCount() > 0) {
95+
TopHits agg = result.getAggregations().get("top");
96+
List<SearchHit> hits = Arrays.asList(agg.getHits().getHits());
97+
assertThat(Strings.toString(result), hits, hasSize(1));
98+
assertThat(Strings.toString(result), hits.get(0).getScore(), allOf(greaterThan(0.0f), lessThan(1.0f)));
99+
}
100+
},
101+
new AggTestConfig(
102+
new RandomSamplerAggregationBuilder("my_agg").subAggregation(AggregationBuilders.topHits("top").size(1))
103+
.setProbability(0.25),
104+
longField(NUMERIC_FIELD_NAME)
105+
).withQuery(
106+
new BooleanQuery.Builder().add(
107+
new TermQuery(new Term(KEYWORD_FIELD_NAME, KEYWORD_FIELD_VALUE)),
108+
BooleanClause.Occur.SHOULD
109+
).build()
110+
)
111+
);
112+
} while (integer.incrementAndGet() < 5);
113+
long avgCount = LongStream.of(counts).sum() / integer.get();
114+
assertThat(avgCount, allOf(greaterThanOrEqualTo(20L), lessThanOrEqualTo(70L)));
115+
}
116+
79117
public void testAggregationSamplingNestedAggsScaled() throws IOException {
80118
// in case 0 docs get sampled, which can rarely happen
81119
// in case the test index has many segments.

0 commit comments

Comments
 (0)