Skip to content

Commit cc104ef

Browse files
committed
correct stats for bucketing
1 parent 781992f commit cc104ef

File tree

1 file changed

+11
-10
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate

1 file changed

+11
-10
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -24,10 +24,10 @@
2424
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
2525
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
2626
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
27+
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
2728
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
2829
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
2930
import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;
30-
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3131
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
3232
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
3333
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
@@ -121,6 +121,8 @@ public interface LogicalPlanRunner {
121121
// TODO: find a good default value, or alternative ways of setting it
122122
private static final int SAMPLE_ROW_COUNT = 100000;
123123

124+
private static final int BUCKET_COUNT = 25;
125+
124126
private static final Logger logger = LogManager.getLogger(Approximate.class);
125127

126128
private final LogicalPlan logicalPlan;
@@ -313,14 +315,15 @@ private LogicalPlan approximatePlan(double sampleProbability) {
313315
if (plan instanceof Aggregate aggregate) {
314316
encounteredStats.set(true);
315317
Expression sampleProbabilityExpr = new Literal(Source.EMPTY, sampleProbability, DataType.DOUBLE);
318+
Expression bucketProbabilityExpr = new Literal(Source.EMPTY, sampleProbability / BUCKET_COUNT, DataType.DOUBLE);
316319
Sample sample = new Sample(Source.EMPTY, sampleProbabilityExpr, aggregate.child());
317320
Alias sampleId = new Alias(
318321
Source.EMPTY,
319322
".sample_id",
320323
new MvAppend(
321324
Source.EMPTY,
322325
new Literal(Source.EMPTY, -1, DataType.INTEGER),
323-
new Random(Source.EMPTY, new Literal(Source.EMPTY, 25, DataType.INTEGER))
326+
new Random(Source.EMPTY, new Literal(Source.EMPTY, BUCKET_COUNT, DataType.INTEGER))
324327
)
325328
);
326329
Eval addSampleId = new Eval(Source.EMPTY, sample, List.of(sampleId));
@@ -337,8 +340,10 @@ private LogicalPlan approximatePlan(double sampleProbability) {
337340
aggregates.add(sampleId.toAttribute());
338341
Aggregate aggregateWithSampledId = (Aggregate) aggregate.with(addSampleId, groupings, aggregates)
339342
.transformExpressionsOnlyUp(
340-
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(sampleProbabilityExpr) : expr
341-
);
343+
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(
344+
new Case(Source.EMPTY,
345+
new Equals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
346+
List.of(sampleProbabilityExpr, bucketProbabilityExpr))) : expr);
342347
aggregates = new ArrayList<>();
343348
for (int i = 0; i < aggregate.aggregates().size(); i++) {
344349
NamedExpression aggr = aggregate.aggregates().get(i);
@@ -356,13 +361,9 @@ private LogicalPlan approximatePlan(double sampleProbability) {
356361
),
357362
new Top(
358363
Source.EMPTY,
359-
new Mul( // TODO: make this mul a sample correction 1/buckets
360-
Source.EMPTY,
361-
Literal.integer(Source.EMPTY, 25),
362-
sampledAggr.toAttribute()
363-
),
364+
sampledAggr.toAttribute(),
364365
new NotEquals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
365-
Literal.integer(Source.EMPTY, 25),
366+
Literal.integer(Source.EMPTY, BUCKET_COUNT),
366367
Literal.keyword(Source.EMPTY, "ASC")
367368
)
368369
)

0 commit comments

Comments
 (0)