|
30 | 30 | import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
|
31 | 31 | import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
|
32 | 32 | import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
|
| 33 | +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; |
33 | 34 | import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
|
34 | 35 | import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
|
35 | 36 | import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvContains;
|
36 | 37 | import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;
|
37 | 38 | import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
|
38 | 39 | import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
|
| 40 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; |
39 | 41 | import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
|
40 | 42 | import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
|
41 | 43 | import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
|
@@ -131,6 +133,11 @@ public interface LogicalPlanRunner {
|
131 | 133 | WeightedAvg.class
|
132 | 134 | );
|
133 | 135 |
|
| 136 | + private static final Set<Class<? extends AggregateFunction>> SAMPLE_CORRECTED_AGGS = Set.of( |
| 137 | + Count.class, |
| 138 | + Sum.class |
| 139 | + ); |
| 140 | + |
134 | 141 | private static final Set<Class<? extends AggregateFunction>> SUPPORTED_MULTI_VALUED_AGGS = Set.of(
|
135 | 142 | org.elasticsearch.xpack.esql.expression.function.aggregate.Sample.class
|
136 | 143 | );
|
@@ -362,11 +369,7 @@ private LogicalPlan approximatePlan(double sampleProbability) {
|
362 | 369 | for (int bucketId = -1; bucketId < bucketCount; bucketId++) {
|
363 | 370 | AggregateFunction bucketedAgg = agg.withFilter(
|
364 | 371 | new MvContains(Source.EMPTY, bucketIdField.toAttribute(), Literal.integer(Source.EMPTY, bucketId)));
|
365 |
| - Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc |
366 |
| - ? nsc.sampleCorrection( |
367 |
| - Literal.fromDouble(Source.EMPTY, bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT) |
368 |
| - ) |
369 |
| - : bucketedAgg; |
| 372 | + Expression correctedAgg = correctForSampling(bucketedAgg, bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT); |
370 | 373 | Alias correctedAggAlias = bucketId == -1
|
371 | 374 | ? aggAlias.replaceChild(correctedAgg)
|
372 | 375 | : new Alias(
|
@@ -477,4 +480,16 @@ private LogicalPlan approximatePlan(double sampleProbability) {
|
477 | 480 | approximatePlan.setPreOptimized();
|
478 | 481 | return approximatePlan;
|
479 | 482 | }
|
| 483 | + |
| 484 | + private static Expression correctForSampling(AggregateFunction agg, double sampleProbability) { |
| 485 | + if (SAMPLE_CORRECTED_AGGS.contains(agg.getClass()) == false) { |
| 486 | + return agg; |
| 487 | + } |
| 488 | + Expression correctedAgg = new Div(agg.source(), agg, Literal.fromDouble(Source.EMPTY, sampleProbability)); |
| 489 | + return switch (agg.dataType()) { |
| 490 | + case DOUBLE -> correctedAgg; |
| 491 | + case LONG -> new ToLong(agg.source(), correctedAgg); |
| 492 | + default -> throw new IllegalStateException("unexpected data type [" + agg.dataType() + "]"); |
| 493 | + }; |
| 494 | + } |
480 | 495 | }
|
0 commit comments