5 changes: 5 additions & 0 deletions docs/changelog/116957.yaml
@@ -0,0 +1,5 @@
pr: 116957
summary: Propagate scoring function through random sampler
area: Machine Learning
type: bug
issues: [ 110134 ]
2 changes: 1 addition & 1 deletion modules/aggregations/build.gradle
@@ -20,7 +20,7 @@ esplugin {

restResources {
restApi {
include '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
include 'capabilities', '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
}
restTests {
// Pulls in all aggregation tests from core AND from the v7 core, for forwards compatibility
YAML REST test suite for the random_sampler aggregation
@@ -142,6 +142,66 @@ setup:
}
- match: { aggregations.sampled.mean.value: 1.0 }
---
"Test random_sampler aggregation with scored subagg":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ random_sampler_with_scored_subaggs ]
test_runner_features: capabilities
reason: "Support for random sampler with scored subaggs capability required"
- do:
search:
index: data
size: 0
body: >
{
"query": {
"function_score": {
"random_score": {}
}
},
"aggs": {
"sampled": {
"random_sampler": {
"probability": 0.5
},
"aggs": {
"top": {
"top_hits": {}
}
}
}
}
}
- is_true: aggregations.sampled.top.hits
- do:
search:
index: data
size: 0
body: >
{
"query": {
"function_score": {
"random_score": {}
}
},
"aggs": {
"sampled": {
"random_sampler": {
"probability": 1.0
},
"aggs": {
"top": {
"top_hits": {}
}
}
}
}
}
- match: { aggregations.sampled.top.hits.total.value: 6 }
- is_true: aggregations.sampled.top.hits.hits.0._score
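(At `probability: 0.5` the first request can only assert that scored hits come back at all; at `probability: 1.0` every document is collected, so the second request can pin the exact hit count of the six-document fixture and require a populated `_score` on the first hit.)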
---
"Test random_sampler aggregation with poor settings":
- requires:
cluster_features: ["gte_v8.2.0"]
SearchCapabilities.java
@@ -23,10 +23,12 @@ private SearchCapabilities() {}
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";

public static final Set<String> CAPABILITIES = Set.of(
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
NESTED_RETRIEVER_INNER_HITS_SUPPORT
NESTED_RETRIEVER_INNER_HITS_SUPPORT,
RANDOM_SAMPLER_WITH_SCORED_SUBAGGS
);
}
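The YAML test above gates on this flag via the capabilities API. Below is a minimal sketch of the same membership check; `CapabilityProbe` is a hypothetical class, while the flag strings are the real values from this file:

```java
import java.util.Set;

// Hypothetical probe; the strings below are the real capability values from this file.
public class CapabilityProbe {
    // Mirrors the shape of SearchCapabilities.CAPABILITIES: a flat set of feature-flag strings.
    private static final Set<String> CAPABILITIES = Set.of(
        "bit_dense_vector_synthetic_source",
        "nested_retriever_inner_hits_support",
        "random_sampler_with_scored_subaggs" // added by this change
    );

    public static void main(String[] args) {
        // A capability-gated REST test runs only when the flag is advertised.
        System.out.println(CAPABILITIES.contains("random_sampler_with_scored_subaggs")); // prints: true
    }
}
```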
AggregatorBase.java
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {

protected final String name;
protected final Aggregator parent;
private final AggregationContext context;
protected final AggregationContext context;
private final Map<String, Object> metadata;

protected final Aggregator[] subAggregators;
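This one-word change is what lets RandomSamplerAggregator, below, call `context.query()` and `context.searcher()` from its new `getWeight()` method. A self-contained sketch of the visibility rule it relies on, with illustrative class names:

```java
// Illustrative classes, not the real Elasticsearch types.
class Base {
    private final String hidden = "invisible to subclasses";
    protected final String shared = "visible to subclasses";
}

class Derived extends Base {
    String use() {
        // return hidden; // would not compile: 'hidden' is private to Base
        return shared;    // fine: protected members are inherited
    }
}

public class VisibilityDemo {
    public static void main(String[] args) {
        System.out.println(new Derived().use()); // prints: visible to subclasses
    }
}
```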
RandomSamplerAggregator.java
@@ -9,18 +9,22 @@

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.support.AggregationContext;
@@ -33,14 +37,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
private final int seed;
private final Integer shardSeed;
private final double probability;
private final CheckedSupplier<Weight, IOException> weightSupplier;
private Weight weight;

RandomSamplerAggregator(
String name,
int seed,
Integer shardSeed,
double probability,
CheckedSupplier<Weight, IOException> weightSupplier,
AggregatorFactories factories,
AggregationContext context,
Aggregator parent,
@@ -55,10 +58,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
);
}
this.weightSupplier = weightSupplier;
this.shardSeed = shardSeed;
}

/**
 * Creates the query weight used by this aggregator.
 *
 * The weight is a boolean query that combines the {@link RandomSamplingQuery} with the configured top-level query of the
 * search. This lets the aggregation iterate the matching documents directly, sampling in the background rather than in
 * the foreground.
 * @return the weight to use; cached for subsequent calls
 * @throws IOException if building the queries or the weight fails
 */
private Weight getWeight() throws IOException {
if (weight == null) {
ScoreMode scoreMode = scoreMode();
BooleanQuery.Builder fullQuery = new BooleanQuery.Builder().add(
context.query(),
scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER
);
if (probability < 1.0) {
Query sampleQuery = new RandomSamplingQuery(probability, seed, shardSeed == null ? context.shardRandomSeed() : shardSeed);
fullQuery.add(sampleQuery, BooleanClause.Occur.FILTER);
}
weight = context.searcher().createWeight(context.searcher().rewrite(fullQuery.build()), scoreMode, 1f);
}
return weight;
}
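The occur flag chosen here is the crux of the fix: a FILTER clause only restricts the matching documents and always contributes a score of 0, so the top-level query must participate as MUST whenever a sub-aggregation such as top_hits needs scores. A standalone sketch of the difference in plain Lucene (not Elasticsearch code; the index content is illustrative):

```java
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class OccurScoringDemo {
    public static void main(String[] args) throws Exception {
        Directory dir = new ByteBuffersDirectory();
        try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new StandardAnalyzer()))) {
            Document doc = new Document();
            doc.add(new TextField("f", "hello world", Field.Store.NO));
            writer.addDocument(doc);
        }
        try (DirectoryReader reader = DirectoryReader.open(dir)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            TermQuery term = new TermQuery(new Term("f", "hello"));
            // Same clause, two occur flags: MUST scores, FILTER matches with score 0.
            float mustScore = searcher.search(
                new BooleanQuery.Builder().add(term, BooleanClause.Occur.MUST).build(), 1
            ).scoreDocs[0].score;
            float filterScore = searcher.search(
                new BooleanQuery.Builder().add(term, BooleanClause.Occur.FILTER).build(), 1
            ).scoreDocs[0].score;
            System.out.println("MUST score=" + mustScore + ", FILTER score=" + filterScore); // FILTER score=0.0
        }
    }
}
```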

@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
return buildAggregationsForSingleBucket(
@@ -100,22 +126,26 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
if (sub.isNoop()) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
// This means there are no docs to iterate, possibly due to the fields not existing
if (scorer == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
sub.setScorer(scorer);

// No sampling is being done, collect all docs
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
if (probability >= 1.0) {
grow(1);
return new LeafBucketCollector() {
return new LeafBucketCollectorBase(sub, null) {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
collectExistingBucket(sub, doc, 0);
}
};
}
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
// This means there are no docs to iterate, possibly due to the fields not existing
if (scorer == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

final DocIdSetIterator docIt = scorer.iterator();
final Bits liveDocs = aggCtx.getLeafReaderContext().reader().getLiveDocs();
try {
Expand All @@ -135,5 +165,4 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
// Since we have done our own collection, there is nothing for the leaf collector to do
return LeafBucketCollector.NO_OP_COLLECTOR;
}

}
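Since the rewritten weight's scorer drives the iteration, only documents that survive the sampling query are ever visited; the sampling happens in the background rather than by scoring every document and discarding most of them. A minimal sketch of that consume-the-scorer loop; the counting stands in for the `collectExistingBucket(sub, doc, 0)` calls in the truncated try block above:

```java
import java.io.IOException;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.Bits;

// Sketch of the aggregator's collection loop: iterate the weight's scorer directly
// and skip deleted documents. 'consume' is a hypothetical helper, not Elasticsearch API.
final class SamplingLoopSketch {
    static long consume(Scorer scorer, LeafReaderContext leaf) throws IOException {
        long collected = 0;
        Bits liveDocs = leaf.reader().getLiveDocs();
        DocIdSetIterator docIt = scorer.iterator();
        for (int doc = docIt.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = docIt.nextDoc()) {
            if (liveDocs == null || liveDocs.get(doc)) {
                collected++; // real code: collectExistingBucket(sub, doc, 0)
            }
        }
        return collected;
    }
}
```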
RandomSamplerAggregatorFactory.java
@@ -9,10 +9,6 @@

package org.elasticsearch.search.aggregations.bucket.sampler.random;

import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
private final Integer shardSeed;
private final double probability;
private final SamplingContext samplingContext;
private Weight weight;

RandomSamplerAggregatorFactory(
String name,
@@ -57,41 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
@Override
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
throws IOException {
return new RandomSamplerAggregator(
name,
seed,
shardSeed,
probability,
this::getWeight,
factories,
context,
parent,
cardinality,
metadata
);
return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
}

/**
* This creates the query weight which will be used in the aggregator.
*
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
* @return weight to be used, is cached for additional usages
* @throws IOException when building the weight or queries fails;
*/
private Weight getWeight() throws IOException {
if (weight == null) {
RandomSamplingQuery query = new RandomSamplingQuery(
probability,
seed,
shardSeed == null ? context.shardRandomSeed() : shardSeed
);
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
.add(context.query(), BooleanClause.Occur.FILTER)
.build();
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
}
return weight;
}

}
RandomSamplerAggregatorTests.java
@@ -11,22 +11,29 @@

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Avg;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.aggregations.metrics.Min;
import org.elasticsearch.search.aggregations.metrics.TopHits;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.DoubleStream;
@@ -37,6 +44,8 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notANumber;
@@ -76,6 +85,35 @@ public void testAggregationSampling() throws IOException {
assertThat(avgAvg, closeTo(1.5, 0.5));
}

public void testAggregationSampling_withScores() throws IOException {
long[] counts = new long[5];
AtomicInteger integer = new AtomicInteger();
do {
testCase(RandomSamplerAggregatorTests::writeTestDocs, (InternalRandomSampler result) -> {
counts[integer.get()] = result.getDocCount();
if (result.getDocCount() > 0) {
TopHits agg = result.getAggregations().get("top");
List<SearchHit> hits = Arrays.asList(agg.getHits().getHits());
assertThat(Strings.toString(result), hits, hasSize(1));
assertThat(Strings.toString(result), hits.get(0).getScore(), allOf(greaterThan(0.0f), lessThan(1.0f)));
}
},
new AggTestConfig(
new RandomSamplerAggregationBuilder("my_agg").subAggregation(AggregationBuilders.topHits("top").size(1))
.setProbability(0.25),
longField(NUMERIC_FIELD_NAME)
).withQuery(
new BooleanQuery.Builder().add(
new TermQuery(new Term(KEYWORD_FIELD_NAME, KEYWORD_FIELD_VALUE)),
BooleanClause.Occur.SHOULD
).build()
)
);
} while (integer.incrementAndGet() < 5);
long avgCount = LongStream.of(counts).sum() / integer.get();
assertThat(avgCount, allOf(greaterThanOrEqualTo(20L), lessThanOrEqualTo(70L)));
}

public void testAggregationSamplingNestedAggsScaled() throws IOException {
// in case 0 docs get sampled, which can rarely happen
// in case the test index has many segments.