Skip to content

Commit 1ee03c0

Browse files
committed
Move the final bucketId aggregation to the end
1 parent be8357e commit 1ee03c0

File tree

1 file changed

+97
-53
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate

1 file changed

+97
-53
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 97 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,10 @@
1414
import org.elasticsearch.xpack.esql.VerificationException;
1515
import org.elasticsearch.xpack.esql.common.Failure;
1616
import org.elasticsearch.xpack.esql.core.expression.Alias;
17+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1718
import org.elasticsearch.xpack.esql.core.expression.Expression;
1819
import org.elasticsearch.xpack.esql.core.expression.Literal;
20+
import org.elasticsearch.xpack.esql.core.expression.NameId;
1921
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
2022
import org.elasticsearch.xpack.esql.core.tree.Source;
2123
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -40,18 +42,22 @@
4042
import org.elasticsearch.xpack.esql.plan.logical.Insist;
4143
import org.elasticsearch.xpack.esql.plan.logical.Keep;
4244
import org.elasticsearch.xpack.esql.plan.logical.LeafPlan;
45+
import org.elasticsearch.xpack.esql.plan.logical.Limit;
4346
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
4447
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
4548
import org.elasticsearch.xpack.esql.plan.logical.Project;
4649
import org.elasticsearch.xpack.esql.plan.logical.Rename;
4750
import org.elasticsearch.xpack.esql.plan.logical.Sample;
51+
import org.elasticsearch.xpack.esql.plan.logical.TopN;
4852
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
4953
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
5054
import org.elasticsearch.xpack.esql.session.Result;
5155

5256
import java.util.ArrayList;
57+
import java.util.HashSet;
5358
import java.util.List;
5459
import java.util.Set;
60+
import java.util.stream.Collectors;
5561

5662
/**
5763
* This class computes approximate and fast results for certain classes of
@@ -293,32 +299,30 @@ private LogicalPlan approximatePlan(double sampleProbability) {
293299

294300
logger.debug("generating approximate plan (p={})", sampleProbability);
295301
Holder<Boolean> encounteredStats = new Holder<>(false);
302+
Set<NameId> variablesWithConfidenceInterval = new HashSet<>();
303+
304+
Alias bucketId = new Alias(
305+
Source.EMPTY,
306+
".bucket_id",
307+
new MvAppend(
308+
Source.EMPTY,
309+
Literal.integer(Source.EMPTY, -1),
310+
new Random(Source.EMPTY, Literal.integer(Source.EMPTY, BUCKET_COUNT))
311+
)
312+
);
313+
296314
LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
297315
if (plan instanceof LeafPlan) {
298316
return new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), plan);
299317
} else if (encounteredStats.get() == false && plan instanceof Aggregate aggregate) {
300318
encounteredStats.set(true);
301-
Alias bucketId = new Alias(
302-
Source.EMPTY,
303-
".bucket_id",
304-
new MvAppend(
305-
Source.EMPTY,
306-
Literal.integer(Source.EMPTY, -1),
307-
new Random(Source.EMPTY, Literal.integer(Source.EMPTY, BUCKET_COUNT))
308-
)
309-
);
319+
310320
Eval addBucketId = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketId));
311-
List<NamedExpression> aggregates = new ArrayList<>();
312-
for (NamedExpression aggr : aggregate.aggregates()) {
313-
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
314-
aggregates.add(new Alias(Source.EMPTY, ".bucketed-" + alias.name(), alias.child()));
315-
} else {
316-
aggregates.add(aggr);
317-
}
318-
}
321+
List<NamedExpression> aggregates = new ArrayList<>(aggregate.aggregates());
322+
aggregates.add(bucketId.toAttribute());
319323
List<Expression> groupings = new ArrayList<>(aggregate.groupings());
320324
groupings.add(bucketId.toAttribute());
321-
aggregates.add(bucketId.toAttribute());
325+
322326
Aggregate aggregateWithBucketId = (Aggregate) aggregate.with(addBucketId, groupings, aggregates)
323327
.transformExpressionsOnlyUp(
324328
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(
@@ -330,48 +334,88 @@ private LogicalPlan approximatePlan(double sampleProbability) {
330334
)
331335
)
332336
) : expr);
333-
aggregates = new ArrayList<>();
334-
for (int i = 0; i < aggregate.aggregates().size(); i++) {
335-
NamedExpression aggr = aggregate.aggregates().get(i);
336-
NamedExpression sampledAggr = aggregateWithBucketId.aggregates().get(i);
337-
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction aggFn) {
338-
// TODO: probably filter low non-empty bucket counts. They're inaccurate and for skew, you need >=3.
339-
aggregates.add(
340-
alias.replaceChild(
341-
new ConfidenceInterval( // TODO: move confidence level to the end
342-
Source.EMPTY,
343-
new Min(
344-
Source.EMPTY,
345-
sampledAggr.toAttribute(),
346-
new Equals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1))
347-
),
348-
new Top(
349-
Source.EMPTY,
350-
sampledAggr.toAttribute(),
351-
new NotEquals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
352-
Literal.integer(Source.EMPTY, BUCKET_COUNT),
353-
Literal.keyword(Source.EMPTY, "ASC")
354-
),
355-
Literal.integer(Source.EMPTY, BUCKET_COUNT),
356-
Literal.fromDouble(Source.EMPTY, aggFn instanceof NeedsSampleCorrection ? 0.0 : Double.NaN)
357-
)
358-
)
359-
);
360-
} else {
361-
aggregates.add(aggr);
337+
338+
for (NamedExpression aggr : aggregate.aggregates()) {
339+
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
340+
variablesWithConfidenceInterval.add(alias.id());
362341
}
363342
}
364-
plan = new Aggregate(
365-
Source.EMPTY,
366-
aggregateWithBucketId,
367-
aggregate.groupings().stream().map(e -> e instanceof Alias a ? a.toAttribute() : e).toList(),
368-
aggregates
369-
);
343+
344+
return aggregateWithBucketId;
345+
} else if (encounteredStats.get()) {
346+
System.out.println("@@@ UPDATE variablesWithConfidenceInterval");
347+
System.out.println("plan = " + plan);
348+
System.out.println("vars = " + variablesWithConfidenceInterval);
349+
switch (plan) {
350+
case Eval eval:
351+
for (NamedExpression field : eval.fields()) {
352+
if (field.anyMatch(expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval.contains(named.id()))) {
353+
variablesWithConfidenceInterval.add(field.id());
354+
}
355+
}
356+
break;
357+
case Rename rename:
358+
// TODO
359+
break;
360+
default:
361+
}
362+
System.out.println("vars = " + variablesWithConfidenceInterval);
370363
}
371364
return plan;
372365
});
373366

367+
System.out.println("### OUTPUT: " + approximatePlan.output());
368+
369+
List<NamedExpression> aggregates = new ArrayList<>();
370+
List<Expression> groupings = new ArrayList<>();
371+
for (Attribute attribute : approximatePlan.output()) {
372+
if (attribute.id() == bucketId.id()) {
373+
continue;
374+
}
375+
if (variablesWithConfidenceInterval.contains(attribute.id())) {
376+
aggregates.add(new Alias(
377+
Source.EMPTY,
378+
attribute.name(),
379+
new ConfidenceInterval(
380+
Source.EMPTY,
381+
new Min(
382+
Source.EMPTY,
383+
attribute,
384+
new Equals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1))
385+
),
386+
new Top(
387+
Source.EMPTY,
388+
attribute,
389+
new NotEquals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
390+
Literal.integer(Source.EMPTY, BUCKET_COUNT),
391+
Literal.keyword(Source.EMPTY, "ASC")
392+
),
393+
Literal.integer(Source.EMPTY, BUCKET_COUNT),
394+
Literal.fromDouble(Source.EMPTY, 0.0) // TODO: fix, 0.0 or NaN ??
395+
)
396+
));
397+
} else {
398+
aggregates.add(attribute);
399+
groupings.add(attribute);
400+
}
401+
}
402+
403+
Aggregate finalAggregate = new Aggregate(
404+
Source.EMPTY,
405+
approximatePlan,
406+
groupings,
407+
aggregates
408+
);
409+
410+
if (approximatePlan instanceof Limit || approximatePlan instanceof TopN) {
411+
approximatePlan = ((UnaryPlan) approximatePlan).replaceChild(finalAggregate.replaceChild(((UnaryPlan) approximatePlan).child()));
412+
} else {
413+
// Can this happen? Or is the last command always a Limit / TopN?
414+
approximatePlan = finalAggregate;
415+
}
416+
374417
logger.info("### AFTER APPROXIMATE:\n{}", approximatePlan);
418+
System.out.println("### OUTPUT: " + approximatePlan.output());
375419

376420
approximatePlan.setPreOptimized();
377421
return approximatePlan;

0 commit comments

Comments
 (0)