Skip to content

Commit a911f59

Browse files
committed
query with confidence interval
1 parent 19da657 commit a911f59

File tree

1 file changed

+79
-14
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate

1 file changed

+79
-14
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 79 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,18 @@
2121
import org.elasticsearch.xpack.esql.core.tree.Source;
2222
import org.elasticsearch.xpack.esql.core.type.DataType;
2323
import org.elasticsearch.xpack.esql.core.util.Holder;
24+
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
2425
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
26+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
2527
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
28+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
29+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
2630
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
2731
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
2832
import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;
33+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
34+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
35+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
2936
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3037
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
3138
import org.elasticsearch.xpack.esql.plan.logical.Drop;
@@ -115,7 +122,7 @@ public interface LogicalPlanRunner {
115122
private static final Set<Class<? extends LogicalPlan>> INCOMPATIBLE_COMMANDS = Set.of(InlineStats.class, LookupJoin.class);
116123

117124
// TODO: find a good default value, or alternative ways of setting it
118-
private static final int SAMPLE_ROW_COUNT = 10000;
125+
private static final int SAMPLE_ROW_COUNT = 100000;
119126

120127
private static final Logger logger = LogManager.getLogger(Approximate.class);
121128

@@ -297,6 +304,9 @@ private LogicalPlan approximatePlan(double sampleProbability) {
297304
logger.debug("using original plan (too few rows)");
298305
return logicalPlan;
299306
}
307+
308+
logger.info("### BEFORE APPROXIMATE:\n{}", logicalPlan);
309+
300310
logger.debug("generating approximate plan (p={})", sampleProbability);
301311
Holder<Boolean> encounteredStats = new Holder<>(false);
302312
LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
@@ -307,29 +317,84 @@ private LogicalPlan approximatePlan(double sampleProbability) {
307317
encounteredStats.set(true);
308318
Expression sampleProbabilityExpr = new Literal(Source.EMPTY, sampleProbability, DataType.DOUBLE);
309319
Sample sample = new Sample(Source.EMPTY, sampleProbabilityExpr, aggregate.child());
310-
Alias sampleId = new Alias(Source.EMPTY, ".sample_id",
311-
new MvAppend(Source.EMPTY, new Literal(Source.EMPTY, 0, DataType.INTEGER), new Random(Source.EMPTY, new Literal(Source.EMPTY, 25, DataType.INTEGER))));
312-
Eval addSampleId = new Eval(
320+
Alias sampleId = new Alias(
313321
Source.EMPTY,
314-
sample,
315-
List.of(sampleId)
316-
);
317-
List<Expression> groupings = new ArrayList<>(aggregate.groupings());
318-
groupings.add(new ReferenceAttribute(Source.EMPTY, null, ".sample_id", DataType.INTEGER, sampleId.nullable(), sampleId.id(), sampleId.synthetic()));
319-
LogicalPlan aggregateWithSampledId = aggregate.with(addSampleId, groupings, aggregate.aggregates()).transformExpressionsOnlyUp(
320-
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(sampleProbabilityExpr) : expr
322+
".sample_id",
323+
new MvAppend(
324+
Source.EMPTY,
325+
new Literal(Source.EMPTY, -1, DataType.INTEGER),
326+
new Random(Source.EMPTY, new Literal(Source.EMPTY, 25, DataType.INTEGER))
327+
)
321328
);
329+
Eval addSampleId = new Eval(Source.EMPTY, sample, List.of(sampleId));
322330
List<NamedExpression> aggregates = new ArrayList<>();
323331
for (NamedExpression aggr : aggregate.aggregates()) {
324-
// aggregates.add(new Alias(Source.EMPTY, "confidence:" + aggr.name(),
325-
// new ConfidenceInterval(Source.EMPTY, new Top(aggr.))));
332+
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
333+
aggregates.add(new Alias(Source.EMPTY, ".sampled-" + alias.name(), alias.child()));
334+
} else {
335+
aggregates.add(aggr);
336+
}
326337
}
327-
plan = new Aggregate(Source.EMPTY, aggregateWithSampledId, aggregate.groupings(), aggregates);
338+
List<Expression> groupings = new ArrayList<>(aggregate.groupings());
339+
groupings.add(sampleId.toAttribute());
340+
aggregates.add(sampleId.toAttribute());
341+
Aggregate aggregateWithSampledId = (Aggregate) aggregate.with(addSampleId, groupings, aggregates)
342+
.transformExpressionsOnlyUp(
343+
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(sampleProbabilityExpr) : expr
344+
);
345+
aggregates = new ArrayList<>();
346+
for (int i = 0; i < aggregate.aggregates().size(); i++) {
347+
NamedExpression aggr = aggregate.aggregates().get(i);
348+
NamedExpression sampledAggr = aggregateWithSampledId.aggregates().get(i);
349+
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
350+
aggregates.add(
351+
alias.replaceChild(
352+
new ToLong( // TODO: cast to original type
353+
Source.EMPTY,
354+
new ConfidenceInterval( // TODO: move confidence level to the end
355+
Source.EMPTY,
356+
new ToDouble(
357+
Source.EMPTY,
358+
new Min(
359+
Source.EMPTY,
360+
sampledAggr.toAttribute(),
361+
new Equals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1))
362+
)
363+
),
364+
new ToDouble(
365+
Source.EMPTY,
366+
new Top(
367+
Source.EMPTY,
368+
new Mul(
369+
Source.EMPTY,
370+
Literal.integer(Source.EMPTY, 25),
371+
sampledAggr.toAttribute()
372+
),
373+
new NotEquals(Source.EMPTY, sampleId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
374+
Literal.integer(Source.EMPTY, 25),
375+
Literal.keyword(Source.EMPTY, "ASC")
376+
)
377+
)
378+
)
379+
)
380+
)
381+
);
382+
} else {
383+
aggregates.add(aggr);
384+
}
385+
}
386+
plan = new Aggregate(
387+
Source.EMPTY,
388+
aggregateWithSampledId,
389+
aggregate.groupings().stream().map(e -> e instanceof Alias a ? a.toAttribute() : e).toList(),
390+
aggregates
391+
);
328392
}
329393
}
330394
return plan;
331395
});
332396

397+
logger.info("### AFTER APPROXIMATE:\n{}", approximatePlan);
333398

334399
approximatePlan.setPreOptimized();
335400
return approximatePlan;

0 commit comments

Comments
 (0)