Skip to content

Commit a1518d0

Browse files
committed
Move sample correction to approximate class
1 parent 742803c commit a1518d0

File tree

4 files changed

+22
-42
lines changed

4 files changed

+22
-42
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,12 +30,14 @@
3030
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
3131
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
3232
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
33+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
3334
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
3435
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
3536
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvContains;
3637
import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;
3738
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
3839
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
40+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3941
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
4042
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
4143
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
@@ -131,6 +133,11 @@ public interface LogicalPlanRunner {
131133
WeightedAvg.class
132134
);
133135

136+
/**
 * Aggregate function classes whose results are rescaled by {@code correctForSampling}.
 * Membership is tested with {@code agg.getClass()}, i.e. by exact class.
 */
private static final Set<Class<? extends AggregateFunction>> SAMPLE_CORRECTED_AGGS = Set.of(
137+
Count.class,
138+
Sum.class
139+
);
140+
134141
private static final Set<Class<? extends AggregateFunction>> SUPPORTED_MULTI_VALUED_AGGS = Set.of(
135142
org.elasticsearch.xpack.esql.expression.function.aggregate.Sample.class
136143
);
@@ -362,11 +369,7 @@ private LogicalPlan approximatePlan(double sampleProbability) {
362369
for (int bucketId = -1; bucketId < bucketCount; bucketId++) {
363370
AggregateFunction bucketedAgg = agg.withFilter(
364371
new MvContains(Source.EMPTY, bucketIdField.toAttribute(), Literal.integer(Source.EMPTY, bucketId)));
365-
Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
366-
? nsc.sampleCorrection(
367-
Literal.fromDouble(Source.EMPTY, bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT)
368-
)
369-
: bucketedAgg;
372+
Expression correctedAgg = correctForSampling(bucketedAgg, bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT);
370373
Alias correctedAggAlias = bucketId == -1
371374
? aggAlias.replaceChild(correctedAgg)
372375
: new Alias(
@@ -477,4 +480,16 @@ private LogicalPlan approximatePlan(double sampleProbability) {
477480
approximatePlan.setPreOptimized();
478481
return approximatePlan;
479482
}
483+
484+
/**
 * Wraps an aggregate so that its result is divided by the sample probability.
 * <p>
 * Aggregates whose class is not in {@code SAMPLE_CORRECTED_AGGS} are returned unchanged.
 * For the others the aggregate is divided by {@code sampleProbability}; when the
 * aggregate's data type is {@code LONG} the quotient is additionally wrapped in
 * {@code ToLong} so the corrected expression keeps the aggregate's data type.
 *
 * @param agg               the aggregate function to correct
 * @param sampleProbability the divisor applied to the aggregate's result
 * @return {@code agg} itself, or an expression dividing it by {@code sampleProbability}
 * @throws IllegalStateException if a corrected aggregate has a data type other than
 *                               {@code DOUBLE} or {@code LONG}
 */
private static Expression correctForSampling(AggregateFunction agg, double sampleProbability) {
485+
// NOTE(review): exact-class lookup; a subclass of a listed aggregate would not be corrected — confirm this is intended.
if (SAMPLE_CORRECTED_AGGS.contains(agg.getClass()) == false) {
486+
return agg;
487+
}
488+
Expression correctedAgg = new Div(agg.source(), agg, Literal.fromDouble(Source.EMPTY, sampleProbability));
489+
// Branch on the aggregate's own data type to decide whether the quotient needs converting back.
return switch (agg.dataType()) {
490+
case DOUBLE -> correctedAgg;
491+
case LONG -> new ToLong(agg.source(), correctedAgg);
492+
default -> throw new IllegalStateException("unexpected data type [" + agg.dataType() + "]");
493+
};
494+
}
480495
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/NeedsSampleCorrection.java

Lines changed: 0 additions & 14 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
1313
import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
1414
import org.elasticsearch.compute.data.AggregateMetricDoubleBlockBuilder;
15-
import org.elasticsearch.xpack.esql.approximate.NeedsSampleCorrection;
1615
import org.elasticsearch.xpack.esql.core.expression.Expression;
1716
import org.elasticsearch.xpack.esql.core.expression.Literal;
1817
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@@ -26,10 +25,8 @@
2625
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
2726
import org.elasticsearch.xpack.esql.expression.function.Param;
2827
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
29-
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
3028
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
3129
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
32-
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3330
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3431
import org.elasticsearch.xpack.esql.planner.ToAggregator;
3532

@@ -40,7 +37,7 @@
4037
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
4138
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
4239

43-
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression, NeedsSampleCorrection {
40+
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression {
4441
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Count", Count::new);
4542

4643
@FunctionInfo(
@@ -188,9 +185,4 @@ public Expression surrogate() {
188185

189186
return null;
190187
}
191-
192-
@Override
193-
public Expression sampleCorrection(Expression sampleProbability) {
194-
return new ToLong(source(), new Div(source(), new Count(source(), field(), filter()), sampleProbability));
195-
}
196188
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier;
1515
import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
1616
import org.elasticsearch.compute.data.AggregateMetricDoubleBlockBuilder;
17-
import org.elasticsearch.xpack.esql.approximate.NeedsSampleCorrection;
1817
import org.elasticsearch.xpack.esql.core.expression.Expression;
1918
import org.elasticsearch.xpack.esql.core.expression.Literal;
2019
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
@@ -28,9 +27,7 @@
2827
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
2928
import org.elasticsearch.xpack.esql.expression.function.Param;
3029
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
31-
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
3230
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
33-
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3431
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3532

3633
import java.io.IOException;
@@ -46,7 +43,7 @@
4643
/**
4744
* Sum all values of a field in matching documents.
4845
*/
49-
public class Sum extends NumericAggregate implements SurrogateExpression, NeedsSampleCorrection {
46+
public class Sum extends NumericAggregate implements SurrogateExpression {
5047
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::readFrom);
5148

5249
private final Expression summationMode;
@@ -176,14 +173,4 @@ public Expression surrogate() {
176173
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
177174
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
178175
}
179-
180-
@Override
181-
public Expression sampleCorrection(Expression sampleProbability) {
182-
Expression correctedSum = new Div(source(), new Sum(source(), field(), filter(), summationMode()), sampleProbability);
183-
return switch (dataType()) {
184-
case DOUBLE -> correctedSum;
185-
case LONG -> new ToLong(source(), correctedSum);
186-
default -> throw new IllegalStateException("unexpected data type [" + dataType() + "]");
187-
};
188-
}
189176
}

0 commit comments

Comments
 (0)