Skip to content

Commit 8f55e07

Browse files
committed
correct aggregations for random sampling
1 parent 38441b9 commit 8f55e07

File tree

6 files changed

+86
-2
lines changed

6 files changed

+86
-2
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public abstract class AggregateFunction extends Function implements PostAnalysis
4040
private final Expression field;
4141
private final List<? extends Expression> parameters;
4242
private final Expression filter;
43+
private final boolean isCorrectedForSampling;
4344

4445
protected AggregateFunction(Source source, Expression field) {
4546
this(source, field, Literal.TRUE, emptyList());
@@ -50,10 +51,21 @@ protected AggregateFunction(Source source, Expression field, List<? extends Expr
5051
}
5152

5253
/**
 * Creates an aggregate function that is not flagged as corrected for sampling.
 * Delegates to the five-argument constructor, passing {@code false} for
 * {@code isCorrectedForSampling}.
 */
protected AggregateFunction(Source source, Expression field, Expression filter, List<? extends Expression> parameters) {
    this(source, field, filter, parameters, false);
}
56+
57+
/**
 * Creates an aggregate function and records whether it has already been
 * corrected for random sampling.
 *
 * @param source                 source passed on to the superclass
 * @param field                  expression stored as this function's field
 * @param filter                 expression stored as this function's filter
 * @param parameters             additional parameter expressions
 * @param isCorrectedForSampling value later reported by {@code isCorrectedForSampling()}
 */
protected AggregateFunction(
    Source source,
    Expression field,
    Expression filter,
    List<? extends Expression> parameters,
    boolean isCorrectedForSampling
) {
    // The children handed to the superclass are: field, filter, then the extra parameters.
    super(source, CollectionUtils.combine(asList(field, filter), parameters));
    this.isCorrectedForSampling = isCorrectedForSampling;
    this.parameters = parameters;
    this.filter = filter;
    this.field = field;
}
5870

5971
protected AggregateFunction(StreamInput in) throws IOException {
@@ -118,6 +130,19 @@ public AggregateFunction withParameters(List<? extends Expression> parameters) {
118130
return (AggregateFunction) replaceChildren(CollectionUtils.combine(asList(field, filter), parameters));
119131
}
120132

133+
public boolean isCorrectedForSampling() {
134+
return isCorrectedForSampling;
135+
}
136+
137+
/**
138+
* Corrects the aggregation in the context of random sampling. By default,
139+
* nothing is done, but subclasses can override this method if some correction
140+
* is needed. See {@link org.elasticsearch.xpack.esql.expression.function.aggregate.Sum} for an example.
141+
*/
142+
public Expression correctForSampling(Expression samplingProbability) {
143+
return this;
144+
}
145+
121146
@Override
122147
public int hashCode() {
123148
// NB: the hashcode is currently used for key generation so
@@ -131,7 +156,8 @@ public boolean equals(Object obj) {
131156
AggregateFunction other = (AggregateFunction) obj;
132157
return Objects.equals(other.field(), field())
133158
&& Objects.equals(other.filter(), filter())
134-
&& Objects.equals(other.parameters(), parameters());
159+
&& Objects.equals(other.parameters(), parameters())
160+
&& other.isCorrectedForSampling() == isCorrectedForSampling();
135161
}
136162
return false;
137163
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
2828
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
2929
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
30+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3031
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3132
import org.elasticsearch.xpack.esql.planner.ToAggregator;
3233

@@ -97,6 +98,10 @@ public Count(Source source, Expression field, Expression filter) {
9798
super(source, field, filter, emptyList());
9899
}
99100

101+
/**
 * Creates a count with no extra parameters and an explicit sampling-correction flag.
 *
 * @param isCorrectedForSampling flag forwarded to the superclass constructor
 */
public Count(Source source, Expression field, Expression filter, boolean isCorrectedForSampling) {
    super(source, field, filter, emptyList(), isCorrectedForSampling);
}
104+
100105
/**
 * Creates a count from a stream; all reading is delegated to the superclass constructor.
 *
 * @throws IOException if the superclass constructor fails to read from {@code in}
 */
private Count(StreamInput in) throws IOException {
    super(in);
}
@@ -169,4 +174,9 @@ public Expression surrogate() {
169174

170175
return null;
171176
}
177+
178+
@Override
179+
public Expression correctForSampling(Expression samplingProbability) {
180+
return isCorrectedForSampling() ? this : new Div(source(), new Count(source(), field(), filter(), true), samplingProbability);
181+
}
172182
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ public abstract class NumericAggregate extends AggregateFunction implements ToAg
5353
super(source, field, filter, parameters);
5454
}
5555

56+
/**
 * Creates a numeric aggregate with an explicit sampling-correction flag.
 * All arguments are forwarded unchanged to the superclass constructor.
 */
NumericAggregate(Source source, Expression field, Expression filter, List<Expression> parameters, boolean isCorrectedForSampling) {
    super(source, field, filter, parameters, isCorrectedForSampling);
}
59+
5660
/**
 * Creates a numeric aggregate from a source and a field only.
 * Both arguments are forwarded unchanged to the superclass constructor.
 */
NumericAggregate(Source source, Expression field) {
    super(source, field);
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.expression.function.Param;
2828
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
2929
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
30+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3031
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3132

3233
import java.io.IOException;
@@ -65,7 +66,11 @@ public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_dou
6566
}
6667

6768
/**
 * Creates a sum that is not flagged as corrected for sampling.
 * Delegates to the four-argument constructor, passing {@code false} for
 * {@code isCorrectedForSampling}.
 */
public Sum(Source source, Expression field, Expression filter) {
    this(source, field, filter, false);
}
71+
72+
/**
 * Creates a sum with no extra parameters and an explicit sampling-correction flag.
 *
 * @param isCorrectedForSampling flag forwarded to the superclass constructor
 */
public Sum(Source source, Expression field, Expression filter, boolean isCorrectedForSampling) {
    super(source, field, filter, emptyList(), isCorrectedForSampling);
}
7075

7176
private Sum(StreamInput in) throws IOException {
@@ -147,4 +152,9 @@ public Expression surrogate() {
147152
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD)))
148153
: null;
149154
}
155+
156+
@Override
157+
public Expression correctForSampling(Expression samplingProbability) {
158+
return isCorrectedForSampling() ? this : new Div(source(), new Sum(source(), field(), filter(), true), samplingProbability);
159+
}
150160
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropagateEvalFoldables;
2828
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropagateInlineEvals;
2929
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropagateNullable;
30+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropagateSampleFrequencyToAggs;
3031
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropgateUnmappedFields;
3132
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumns;
3233
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneEmptyPlans;
@@ -127,6 +128,7 @@ protected static Batch<LogicalPlan> substitutions() {
127128
"Substitutions",
128129
Limiter.ONCE,
129130
new SubstituteSurrogatePlans(),
131+
new PropagateSampleFrequencyToAggs(),
130132
// Translate filtered expressions into aggregate with filters - can't use surrogate expressions because it was
131133
// retrofitted for constant folding - this needs to be fixed.
132134
// Needs to occur before ReplaceAggregateAggExpressionWithEval, which will update the functions, losing the filter.
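x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateSampleFrequencyToAggs.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change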
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
12+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
13+
import org.elasticsearch.xpack.esql.plan.logical.RandomSample;
14+
import org.elasticsearch.xpack.esql.rule.Rule;
15+
16+
import java.util.concurrent.atomic.AtomicReference;
17+
18+
public class PropagateSampleFrequencyToAggs extends Rule<LogicalPlan, LogicalPlan> {
19+
@Override
20+
public LogicalPlan apply(LogicalPlan logicalPlan) {
21+
AtomicReference<Expression> sampleProbability = new AtomicReference<>(null);
22+
return logicalPlan.transformUp(plan -> {
23+
if (plan instanceof RandomSample randomSample) {
24+
sampleProbability.set(randomSample.probability());
25+
}
26+
if (sampleProbability.get() != null) {
27+
plan = plan.transformExpressionsOnly(AggregateFunction.class, af -> af.correctForSampling(sampleProbability.get()));
28+
}
29+
return plan;
30+
});
31+
}
32+
}

0 commit comments

Comments
 (0)