Skip to content

Commit 42f3b79

Browse files
committed
Refactor sample correction
1 parent d7b9434 commit 42f3b79

File tree

9 files changed

+88
-42
lines changed

9 files changed

+88
-42
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ public abstract class AggregateFunction extends Function implements PostAnalysis
4040
private final Expression field;
4141
private final List<? extends Expression> parameters;
4242
private final Expression filter;
43-
private final boolean isCorrectedForSampling;
4443

4544
protected AggregateFunction(Source source, Expression field) {
4645
this(source, field, Literal.TRUE, emptyList());
@@ -51,21 +50,10 @@ protected AggregateFunction(Source source, Expression field, List<? extends Expr
5150
}
5251

5352
protected AggregateFunction(Source source, Expression field, Expression filter, List<? extends Expression> parameters) {
54-
this(source, field, filter, parameters, false);
55-
}
56-
57-
protected AggregateFunction(
58-
Source source,
59-
Expression field,
60-
Expression filter,
61-
List<? extends Expression> parameters,
62-
boolean isCorrectedForSampling
63-
) {
6453
super(source, CollectionUtils.combine(asList(field, filter), parameters));
6554
this.field = field;
6655
this.filter = filter;
6756
this.parameters = parameters;
68-
this.isCorrectedForSampling = isCorrectedForSampling;
6957
}
7058

7159
protected AggregateFunction(StreamInput in) throws IOException {
@@ -130,10 +118,6 @@ public AggregateFunction withParameters(List<? extends Expression> parameters) {
130118
return (AggregateFunction) replaceChildren(CollectionUtils.combine(asList(field, filter), parameters));
131119
}
132120

133-
public boolean isCorrectedForSampling() {
134-
return isCorrectedForSampling;
135-
}
136-
137121
/**
138122
* Corrects the aggregation in the context of random sampling. By default,
139123
* nothing is done, but subclasses can override this method if some correction
@@ -156,8 +140,7 @@ public boolean equals(Object obj) {
156140
AggregateFunction other = (AggregateFunction) obj;
157141
return Objects.equals(other.field(), field())
158142
&& Objects.equals(other.filter(), filter())
159-
&& Objects.equals(other.parameters(), parameters())
160-
&& other.isCorrectedForSampling() == isCorrectedForSampling();
143+
&& Objects.equals(other.parameters(), parameters());
161144
}
162145
return false;
163146
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
2828
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
2929
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
30-
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3130
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3231
import org.elasticsearch.xpack.esql.planner.ToAggregator;
3332

@@ -38,7 +37,7 @@
3837
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
3938
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
4039

41-
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression {
40+
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression, HasSampleCorrection {
4241
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Count", Count::new);
4342

4443
@FunctionInfo(
@@ -98,10 +97,6 @@ public Count(Source source, Expression field, Expression filter) {
9897
super(source, field, filter, emptyList());
9998
}
10099

101-
public Count(Source source, Expression field, Expression filter, boolean isCorrectedForSampling) {
102-
super(source, field, filter, emptyList(), isCorrectedForSampling);
103-
}
104-
105100
private Count(StreamInput in) throws IOException {
106101
super(in);
107102
}
@@ -176,7 +171,7 @@ public Expression surrogate() {
176171
}
177172

178173
@Override
179-
public Expression correctForSampling(Expression samplingProbability) {
180-
return isCorrectedForSampling() ? this : new Div(source(), new Count(source(), field(), filter(), true), samplingProbability);
174+
public Expression sampleCorrection(Expression sampleProbability) {
175+
return new CountSampleCorrection(source(), field(), filter(), sampleProbability);
181176
}
182177
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.aggregate;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
import org.elasticsearch.xpack.esql.core.tree.Source;
12+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
13+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
14+
15+
/**
 * A {@link Count} variant that additionally carries a sample probability.
 * Its {@link #surrogate()} divides a plain {@code Count} over the same field
 * and filter by that probability and wraps the quotient in {@code ToLong}.
 *
 * NOTE(review): {@code sampleProbability} is kept only as a field of this
 * class and is not passed to {@code super(...)}; confirm that equality,
 * hashing, child replacement and serialization treat it as intended.
 */
public class CountSampleCorrection extends Count {
16+
// Expression supplied by the caller; used as the divisor in surrogate().
17+
private final Expression sampleProbability;
18+
// Forwards source, field and filter to Count and stores the probability locally.
19+
public CountSampleCorrection(Source source, Expression field, Expression filter, Expression sampleProbability) {
20+
super(source, field, filter);
21+
this.sampleProbability = sampleProbability;
22+
}
23+
// Builds ToLong(Div(Count(field, filter), sampleProbability)).
// The inner aggregate is a plain Count, not this class
// (presumably so the replacement is not corrected a second time; TODO confirm).
24+
@Override
25+
public Expression surrogate() {
26+
return new ToLong(source(), new Div(source(), new Count(source(), field(), filter()), sampleProbability));
27+
}
28+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.aggregate;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
12+
/**
 * Implemented by expressions that can supply a replacement expression for
 * use when a sample probability is known. In this change it is implemented
 * by {@code Count} and {@code Sum}, and {@code PropagateSampleFrequencyToAggs}
 * swaps each implementing expression for the result of
 * {@link #sampleCorrection(Expression)}.
 */
public interface HasSampleCorrection {
13+
/**
 * @param sampleProbability the expression holding the sample probability
 * @return the expression that should replace this one
 */
14+
Expression sampleCorrection(Expression sampleProbability);
15+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ public abstract class NumericAggregate extends AggregateFunction implements ToAg
5353
super(source, field, filter, parameters);
5454
}
5555

56-
NumericAggregate(Source source, Expression field, Expression filter, List<Expression> parameters, boolean isCorrectedForSampling) {
57-
super(source, field, filter, parameters, isCorrectedForSampling);
58-
}
59-
6056
NumericAggregate(Source source, Expression field) {
6157
super(source, field);
6258
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.elasticsearch.xpack.esql.expression.function.Param;
2828
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
2929
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
30-
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3130
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3231

3332
import java.io.IOException;
@@ -44,7 +43,7 @@
4443
/**
4544
* Sum all values of a field in matching documents.
4645
*/
47-
public class Sum extends NumericAggregate implements SurrogateExpression {
46+
public class Sum extends NumericAggregate implements SurrogateExpression, HasSampleCorrection {
4847
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new);
4948

5049
@FunctionInfo(
@@ -66,11 +65,7 @@ public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_dou
6665
}
6766

6867
public Sum(Source source, Expression field, Expression filter) {
69-
this(source, field, filter, false);
70-
}
71-
72-
public Sum(Source source, Expression field, Expression filter, boolean isCorrectedForSampling) {
73-
super(source, field, filter, emptyList(), isCorrectedForSampling);
68+
super(source, field, filter, emptyList());
7469
}
7570

7671
private Sum(StreamInput in) throws IOException {
@@ -154,7 +149,7 @@ public Expression surrogate() {
154149
}
155150

156151
@Override
157-
public Expression correctForSampling(Expression samplingProbability) {
158-
return isCorrectedForSampling() ? this : new Div(source(), new Sum(source(), field(), filter(), true), samplingProbability);
152+
public Expression sampleCorrection(Expression sampleProbability) {
153+
return new SumSampleCorrection(source(), field(), filter(), sampleProbability);
159154
}
160155
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.aggregate;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
import org.elasticsearch.xpack.esql.core.tree.Source;
12+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
13+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
14+
15+
/**
 * A {@link Sum} variant that additionally carries a sample probability.
 * Its {@link #surrogate()} divides a plain {@code Sum} over the same field
 * and filter by that probability, converting with {@code ToLong} when this
 * function's data type is {@code LONG}.
 *
 * NOTE(review): {@code sampleProbability} is kept only as a field of this
 * class and is not passed to {@code super(...)}; confirm that equality,
 * hashing, child replacement and serialization treat it as intended.
 */
public class SumSampleCorrection extends Sum {
16+
// Expression supplied by the caller; used as the divisor in surrogate().
17+
private final Expression sampleProbability;
18+
// Forwards source, field and filter to Sum and stores the probability locally.
19+
public SumSampleCorrection(Source source, Expression field, Expression filter, Expression sampleProbability) {
20+
super(source, field, filter);
21+
this.sampleProbability = sampleProbability;
22+
}
23+
// Chooses the replacement by dataType():
//   DOUBLE -> Div(Sum(field, filter), sampleProbability)
//   LONG   -> the same quotient wrapped in ToLong
//   other  -> IllegalStateException naming the unexpected type
// The inner aggregate is a plain Sum, not this class.
24+
@Override
25+
public Expression surrogate() {
26+
return switch (dataType()) {
27+
case DOUBLE -> new Div(source(), new Sum(source(), field(), filter()), sampleProbability);
28+
case LONG -> new ToLong(source(), new Div(source(), new Sum(source(), field(), filter()), sampleProbability));
29+
default -> throw new IllegalStateException("unexpected data type [" + dataType() + "]");
30+
};
31+
}
32+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ protected static Batch<LogicalPlan> substitutions() {
127127
return new Batch<>(
128128
"Substitutions",
129129
Limiter.ONCE,
130-
new SubstituteSurrogatePlans(),
131130
new PropagateSampleFrequencyToAggs(),
131+
new SubstituteSurrogatePlans(),
132132
// Translate filtered expressions into aggregate with filters - can't use surrogate expressions because it was
133133
// retrofitted for constant folding - this needs to be fixed.
134134
// Needs to occur before ReplaceAggregateAggExpressionWithEval, which will update the functions, losing the filter.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateSampleFrequencyToAggs.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99

1010
import org.elasticsearch.xpack.esql.core.expression.Expression;
1111
import org.elasticsearch.xpack.esql.core.util.Holder;
12-
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
12+
import org.elasticsearch.xpack.esql.expression.function.aggregate.HasSampleCorrection;
1313
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
1414
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1515
import org.elasticsearch.xpack.esql.plan.logical.RandomSample;
1616
import org.elasticsearch.xpack.esql.rule.Rule;
1717

1818
public class PropagateSampleFrequencyToAggs extends Rule<LogicalPlan, LogicalPlan> {
19+
1920
@Override
2021
public LogicalPlan apply(LogicalPlan logicalPlan) {
2122
Holder<Expression> sampleProbability = new Holder<>(null);
@@ -24,7 +25,8 @@ public LogicalPlan apply(LogicalPlan logicalPlan) {
2425
sampleProbability.set(randomSample.probability());
2526
}
2627
if (plan instanceof Aggregate && sampleProbability.get() != null) {
27-
plan = plan.transformExpressionsOnly(AggregateFunction.class, af -> af.correctForSampling(sampleProbability.get()));
28+
plan = plan.transformExpressionsOnly(
29+
e -> e instanceof HasSampleCorrection hsc ? hsc.sampleCorrection(sampleProbability.get()) : e);
2830
sampleProbability.set(null);
2931
}
3032
return plan;

0 commit comments

Comments
 (0)