Skip to content

Commit 3d3ca89

Browse files
committed
One column per bucket
1 parent 141d083 commit 3d3ca89

File tree

3 files changed

+105
-94
lines changed

3 files changed

+105
-94
lines changed

x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/random/RandomEvaluator.java

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 97 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -19,20 +19,18 @@
1919
import org.elasticsearch.xpack.esql.core.expression.Literal;
2020
import org.elasticsearch.xpack.esql.core.expression.NameId;
2121
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
22-
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
2322
import org.elasticsearch.xpack.esql.core.tree.Source;
24-
import org.elasticsearch.xpack.esql.core.type.DataType;
2523
import org.elasticsearch.xpack.esql.core.util.Holder;
2624
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
2725
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
28-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
2926
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
3027
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
31-
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
3228
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
3329
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
30+
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvContains;
3431
import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;
3532
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
33+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
3634
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
3735
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3836
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
@@ -56,8 +54,10 @@
5654
import org.elasticsearch.xpack.esql.session.Result;
5755

5856
import java.util.ArrayList;
57+
import java.util.HashMap;
5958
import java.util.HashSet;
6059
import java.util.List;
60+
import java.util.Map;
6161
import java.util.Set;
6262
import java.util.stream.Collectors;
6363

@@ -120,7 +120,7 @@ public interface LogicalPlanRunner {
120120
// TODO: find a good default value, or alternative ways of setting it
121121
private static final int SAMPLE_ROW_COUNT = 100000;
122122

123-
private static final int BUCKET_COUNT = 25;
123+
private static final int BUCKET_COUNT = 3; // 25;
124124

125125
private static final Logger logger = LogManager.getLogger(Approximate.class);
126126

@@ -301,12 +301,11 @@ private LogicalPlan approximatePlan(double sampleProbability) {
301301

302302
logger.debug("generating approximate plan (p={})", sampleProbability);
303303
Holder<Boolean> encounteredStats = new Holder<>(false);
304-
Set<NameId> variablesWithConfidenceInterval = new HashSet<>();
305-
Set<NameId> variablesWithPastConfidenceInterval = new HashSet<>();
304+
Map<NameId, List<Alias>> variablesWithConfidenceInterval = new HashMap<>();
306305

307-
Alias bucketId = new Alias(
306+
Alias bucketIdField = new Alias(
308307
Source.EMPTY,
309-
".bucket_id",
308+
"$$bucket_id",
310309
new MvAppend(
311310
Source.EMPTY,
312311
Literal.integer(Source.EMPTY, -1),
@@ -316,125 +315,134 @@ private LogicalPlan approximatePlan(double sampleProbability) {
316315

317316
LogicalPlan approximatePlan = logicalPlan.transformUp(plan -> {
318317
if (plan instanceof LeafPlan) {
319-
return new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), plan);
318+
plan = new Sample(Source.EMPTY, Literal.fromDouble(Source.EMPTY, sampleProbability), plan);
320319
} else if (encounteredStats.get() == false && plan instanceof Aggregate aggregate) {
321320
encounteredStats.set(true);
322321

323-
Eval addBucketId = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketId));
324-
List<NamedExpression> aggregates = new ArrayList<>(aggregate.aggregates());
325-
aggregates.add(bucketId.toAttribute());
326-
List<Expression> groupings = new ArrayList<>(aggregate.groupings());
327-
groupings.add(bucketId.toAttribute());
328-
329-
Aggregate aggregateWithBucketId = (Aggregate) aggregate.with(addBucketId, groupings, aggregates)
330-
.transformExpressionsOnlyUp(
331-
expr -> expr instanceof NeedsSampleCorrection nsc ? nsc.sampleCorrection(
332-
new Case(Source.EMPTY,
333-
new Equals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
334-
List.of(
335-
Literal.fromDouble(Source.EMPTY, sampleProbability),
336-
Literal.fromDouble(Source.EMPTY, sampleProbability / BUCKET_COUNT)
337-
)
338-
)
339-
) : expr);
340-
341-
for (NamedExpression aggr : aggregate.aggregates()) {
342-
if (aggr instanceof Alias alias && alias.child() instanceof AggregateFunction) {
343-
variablesWithConfidenceInterval.add(alias.id());
322+
Eval addBucketId = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketIdField));
323+
List<NamedExpression> aggregates = new ArrayList<>();
324+
for (NamedExpression aggOrKey : aggregate.aggregates()) {
325+
if ((aggOrKey instanceof Alias alias && alias.child() instanceof AggregateFunction) == false) {
326+
// This is a grouping key, not an aggregate function.
327+
aggregates.add(aggOrKey);
328+
continue;
344329
}
330+
Alias aggAlias = (Alias) aggOrKey;
331+
AggregateFunction agg = (AggregateFunction) aggAlias.child();
332+
List<Alias> bucketedAggs = new ArrayList<>();
333+
for (int bucketId = -1; bucketId < BUCKET_COUNT; bucketId++) {
334+
AggregateFunction bucketedAgg = agg.withFilter(
335+
new MvContains(Source.EMPTY, bucketIdField.toAttribute(), Literal.integer(Source.EMPTY, bucketId)));
336+
Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
337+
? nsc.sampleCorrection(
338+
Literal.fromDouble(Source.EMPTY, bucketId == -1 ? sampleProbability : sampleProbability / BUCKET_COUNT)
339+
)
340+
: bucketedAgg;
341+
Alias correctAggAlias = bucketId == -1
342+
? aggAlias.replaceChild(correctedAgg)
343+
: new Alias(
344+
Source.EMPTY,
345+
aggOrKey.name() + "$bucket:" + bucketId,
346+
correctedAgg
347+
);
348+
aggregates.add(correctAggAlias);
349+
if (bucketId >= 0) {
350+
bucketedAggs.add(correctAggAlias);
351+
}
352+
}
353+
variablesWithConfidenceInterval.put(aggOrKey.id(), bucketedAggs);
345354
}
355+
plan = aggregate.with(addBucketId, aggregate.groupings(), aggregates);
346356

347-
return aggregateWithBucketId;
348357
} else if (encounteredStats.get()) {
349358
System.out.println("@@@ UPDATE variablesWithConfidenceInterval");
350359
System.out.println("plan = " + plan);
351-
System.out.println("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval);
360+
System.out.println("vars = " + variablesWithConfidenceInterval);
352361
switch (plan) {
353362
case Eval eval:
363+
List<Alias> newFields = new ArrayList<>(eval.fields());
354364
for (Alias field : eval.fields()) {
355-
if (field.anyMatch(expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval.contains(named.id()))) {
356-
// TODO: blacklist / whitelist?
357-
if (field.child() instanceof MvAppend == false && field.dataType().isNumeric()) {
358-
variablesWithConfidenceInterval.add(field.id());
359-
} else {
360-
variablesWithPastConfidenceInterval.add(field.id());
365+
if (field.dataType().isNumeric() == false || field.child().anyMatch(expr -> expr instanceof MvAppend)) {
366+
continue;
367+
}
368+
if (field.child().anyMatch(expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval.containsKey(named.id()))) {
369+
List<Alias> newBuckets = new ArrayList<>();
370+
for (int bucketId = 0; bucketId < BUCKET_COUNT; bucketId++) {
371+
final int finalBucketId = bucketId;
372+
Expression newChild = field.child().transformDown(expr -> {
373+
if (expr instanceof NamedExpression named && variablesWithConfidenceInterval.containsKey(named.id())) {
374+
List<Alias> buckets = variablesWithConfidenceInterval.get(named.id());
375+
return buckets.get(finalBucketId).toAttribute();
376+
} else {
377+
return expr;
378+
}
379+
});
380+
Alias newField = new Alias(
381+
Source.EMPTY,
382+
field.name() + "$bucket:" + bucketId,
383+
newChild
384+
);
385+
newBuckets.add(newField);
361386
}
362-
} else if (field.anyMatch(expr -> expr instanceof NamedExpression named && variablesWithPastConfidenceInterval.contains(named.id()))) {
363-
variablesWithPastConfidenceInterval.add(field.id());
387+
variablesWithConfidenceInterval.put(field.id(), newBuckets);
388+
newFields.addAll(newBuckets);
364389
}
365390
}
391+
plan = new Eval(Source.EMPTY, eval.child(), newFields);
366392
break;
367-
case Project project:
368-
List<NamedExpression> projections = new ArrayList<>(project.projections());
369-
projections.add(bucketId.toAttribute());
370-
plan = project.withProjections(projections);
371-
break;
393+
// case Project project:
394+
// List<NamedExpression> projections = new ArrayList<>(project.projections());
395+
// plan = project.withProjections(projections);
396+
// break;
372397
case Rename rename:
373398
// TODO
374399
break;
375400
default:
376401
}
377-
System.out.println("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval);
402+
System.out.println("vars = " + variablesWithConfidenceInterval);
378403
}
379404
return plan;
380405
});
381406

382407
System.out.println("### OUTPUT: " + approximatePlan.output());
383408

384-
List<NamedExpression> aggregates = new ArrayList<>();
385-
List<Expression> groupings = new ArrayList<>();
386-
for (Attribute attribute : approximatePlan.output()) {
387-
if (attribute.id() == bucketId.id()) {
388-
continue;
389-
}
390-
if (variablesWithConfidenceInterval.contains(attribute.id()) || variablesWithPastConfidenceInterval.contains(attribute.id())) {
391-
Alias bestEstimate = new Alias(
409+
List<Alias> confidenceIntervals = new ArrayList<>();
410+
for (Attribute output : logicalPlan.output()) {
411+
if (variablesWithConfidenceInterval.containsKey(output.id())) {
412+
List<Alias> buckets = variablesWithConfidenceInterval.get(output.id());
413+
Expression appendedBuckets = buckets.getFirst().toAttribute();
414+
for (int i = 1; i < buckets.size(); i++) {
415+
appendedBuckets = new MvAppend(Source.EMPTY, appendedBuckets, buckets.get(i).toAttribute());
416+
}
417+
confidenceIntervals.add(new Alias(
392418
Source.EMPTY,
393-
attribute.name(),
394-
new Values(
395-
Source.EMPTY,
396-
attribute,
397-
new Equals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1))
398-
)
399-
);
400-
aggregates.add(bestEstimate);
401-
if (variablesWithConfidenceInterval.contains(attribute.id())) {
402-
aggregates.add(new Alias(
403-
Source.EMPTY, "CONFIDENCE_INTERVAL(" + attribute.name() + ")", new ConfidenceInterval(
419+
"CONFIDENCE_INTERVAL(" + output.name() + ")",
420+
new ConfidenceInterval(
404421
Source.EMPTY,
405-
bestEstimate.toAttribute(),
406-
new Top(
407-
Source.EMPTY,
408-
attribute,
409-
new NotEquals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1)),
410-
Literal.integer(Source.EMPTY, BUCKET_COUNT),
411-
Literal.keyword(Source.EMPTY, "ASC")
412-
),
422+
output,
423+
appendedBuckets,
413424
Literal.integer(Source.EMPTY, BUCKET_COUNT),
414425
Literal.fromDouble(Source.EMPTY, 0.0)
415-
// TODO: fix, 0.0 or NaN ?? TODO: remove!!
416426
)
417-
));
418-
}
419-
} else {
420-
aggregates.add(attribute);
421-
groupings.add(attribute);
427+
));
422428
}
423429
}
424430

425-
Aggregate finalAggregate = new Aggregate(
431+
approximatePlan = new Eval(
426432
Source.EMPTY,
427433
approximatePlan,
428-
groupings,
429-
aggregates
434+
confidenceIntervals
430435
);
431436

432-
if (approximatePlan instanceof Limit || approximatePlan instanceof TopN) {
433-
approximatePlan = ((UnaryPlan) approximatePlan).replaceChild(finalAggregate.replaceChild(((UnaryPlan) approximatePlan).child()));
434-
} else {
435-
// Can this happen? Or is the last command always a Limit / TopN?
436-
approximatePlan = finalAggregate;
437-
}
437+
Set<Attribute> dropAttributes = variablesWithConfidenceInterval.values().stream().flatMap(List::stream).map(Alias::toAttribute).collect(Collectors.toSet());
438+
List<Attribute> keepAttributes = new ArrayList<>(approximatePlan.output());
439+
keepAttributes.removeAll(dropAttributes);
440+
441+
approximatePlan = new Project(
442+
Source.EMPTY,
443+
approximatePlan,
444+
keepAttributes
445+
);
438446

439447
logger.info("### AFTER APPROXIMATE:\n{}", approximatePlan);
440448
System.out.println("### OUTPUT: " + approximatePlan.output());

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/ConfidenceInterval.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.io.stream.StreamInput;
1616
import org.elasticsearch.common.io.stream.StreamOutput;
1717
import org.elasticsearch.compute.ann.Evaluator;
18+
import org.elasticsearch.compute.ann.Position;
1819
import org.elasticsearch.compute.data.DoubleBlock;
1920
import org.elasticsearch.compute.data.IntBlock;
2021
import org.elasticsearch.compute.data.LongBlock;
@@ -144,7 +145,7 @@ public boolean equals(Object obj) {
144145
}
145146

146147
@Evaluator(extraName = "Double")
147-
static void process(DoubleBlock.Builder builder, int position, DoubleBlock bestEstimateBlock, DoubleBlock estimatesBlock, IntBlock bucketCountBlock, DoubleBlock emptyBucketValueBlock) {
148+
static void process(DoubleBlock.Builder builder, @Position int position, DoubleBlock bestEstimateBlock, DoubleBlock estimatesBlock, IntBlock bucketCountBlock, DoubleBlock emptyBucketValueBlock) {
148149
assert bestEstimateBlock.getValueCount(position) == 1 : "bestEstimate: expected 1 element, got " + bestEstimateBlock.getValueCount(position);
149150
assert bucketCountBlock.getValueCount(position) == 1 : "bucketCount: expected 1 element, got " + bucketCountBlock.getValueCount(position);
150151
assert emptyBucketValueBlock.getValueCount(position) == 1 : "emptyBucketValue: expected 1 element, got " + emptyBucketValueBlock.getValueCount(position);
@@ -168,7 +169,7 @@ static void process(DoubleBlock.Builder builder, int position, DoubleBlock bestE
168169
}
169170

170171
@Evaluator(extraName = "Int")
171-
static void process(IntBlock.Builder builder, int position, IntBlock bestEstimateBlock, IntBlock estimatesBlock, IntBlock bucketCountBlock, DoubleBlock emptyBucketValueBlock) {
172+
static void process(IntBlock.Builder builder, @Position int position, IntBlock bestEstimateBlock, IntBlock estimatesBlock, IntBlock bucketCountBlock, DoubleBlock emptyBucketValueBlock) {
172173
assert bestEstimateBlock.getValueCount(position) == 1 : "bestEstimate: expected 1 element, got " + bestEstimateBlock.getValueCount(position);
173174
assert bucketCountBlock.getValueCount(position) == 1 : "bucketCount: expected 1 element, got " + bucketCountBlock.getValueCount(position);
174175
assert emptyBucketValueBlock.getValueCount(position) == 1 : "emptyBucketValue: expected 1 element, got " + emptyBucketValueBlock.getValueCount(position);
@@ -190,7 +191,7 @@ static void process(IntBlock.Builder builder, int position, IntBlock bestEstimat
190191
}
191192

192193
@Evaluator(extraName = "Long")
193-
static void process(LongBlock.Builder builder, int position, LongBlock bestEstimateBlock, LongBlock estimatesBlock, IntBlock bucketCountBlock, DoubleBlock emptyBucketValueBlock) {
194+
static void process(LongBlock.Builder builder, @Position int position, LongBlock bestEstimateBlock, LongBlock estimatesBlock, IntBlock bucketCountBlock, DoubleBlock emptyBucketValueBlock) {
194195
assert bestEstimateBlock.getValueCount(position) == 1 : "bestEstimate: expected 1 element, got " + bestEstimateBlock.getValueCount(position);
195196
assert bucketCountBlock.getValueCount(position) == 1 : "bucketCount: expected 1 element, got " + bucketCountBlock.getValueCount(position);
196197
assert emptyBucketValueBlock.getValueCount(position) == 1 : "emptyBucketValue: expected 1 element, got " + emptyBucketValueBlock.getValueCount(position);

0 commit comments

Comments
 (0)