Skip to content

Commit 742803c

Browse files
committed
whitelist agg functions
1 parent 9c27eb1 commit 742803c

File tree

1 file changed

+47
-9
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate

1 file changed

+47
-9
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,21 @@
2222
import org.elasticsearch.xpack.esql.core.tree.Source;
2323
import org.elasticsearch.xpack.esql.core.util.Holder;
2424
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
25+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg;
2526
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
27+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Median;
28+
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
29+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
30+
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
31+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
32+
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
2633
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
2734
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
2835
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvContains;
2936
import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;
3037
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
3138
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
39+
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
3240
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3341
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
3442
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
@@ -52,6 +60,7 @@
5260
import java.util.ArrayList;
5361
import java.util.HashMap;
5462
import java.util.List;
63+
import java.util.Locale;
5564
import java.util.Map;
5665
import java.util.Set;
5766
import java.util.stream.Collectors;
@@ -93,7 +102,6 @@ public interface LogicalPlanRunner {
93102
void run(LogicalPlan plan, ActionListener<Result> listener);
94103
}
95104

96-
97105
/**
98106
* These commands preserve all rows, making it easy to predict the number of output rows.
99107
*/
@@ -112,10 +120,25 @@ public interface LogicalPlanRunner {
112120
Rename.class
113121
);
114122

123+
private static final Set<Class<? extends AggregateFunction>> SUPPORTED_SINGLE_VALUED_AGGS = Set.of(
124+
Avg.class,
125+
Count.class,
126+
Median.class,
127+
MedianAbsoluteDeviation.class,
128+
Percentile.class,
129+
StdDev.class,
130+
Sum.class,
131+
WeightedAvg.class
132+
);
133+
134+
private static final Set<Class<? extends AggregateFunction>> SUPPORTED_MULTI_VALUED_AGGS = Set.of(
135+
org.elasticsearch.xpack.esql.expression.function.aggregate.Sample.class
136+
);
137+
115138
// TODO: find a good default value, or alternative ways of setting it
116139
private static final int SAMPLE_ROW_COUNT = 100000;
117140

118-
private static final int BUCKET_COUNT = 3;
141+
private static final int BUCKET_COUNT = 16;
119142

120143
private static final Logger logger = LogManager.getLogger(Approximate.class);
121144

@@ -153,7 +176,7 @@ private boolean verifyPlan() {
153176
logicalPlan.forEachUp(plan -> {
154177
if (plan instanceof LeafPlan == false && plan instanceof UnaryPlan == false) {
155178
throw new VerificationException(
156-
List.of(Failure.fail(plan, "query with [" + plan.nodeName() + "] cannot be approximated"))
179+
List.of(Failure.fail(plan, "query with [" + plan.nodeName().toUpperCase(Locale.ROOT) + "] cannot be approximated"))
157180
);
158181
}
159182
});
@@ -162,8 +185,16 @@ private boolean verifyPlan() {
162185
Holder<Boolean> hasFilters = new Holder<>(false);
163186
logicalPlan.transformUp(plan -> {
164187
if (encounteredStats.get() == false) {
165-
if (plan instanceof Aggregate) {
188+
if (plan instanceof Aggregate aggregate) {
166189
encounteredStats.set(true);
190+
plan.transformExpressionsOnly(AggregateFunction.class, aggFn -> {
191+
if (SUPPORTED_SINGLE_VALUED_AGGS.contains(aggFn.getClass()) == false && SUPPORTED_MULTI_VALUED_AGGS.contains(aggFn.getClass()) == false) {
192+
throw new VerificationException(
193+
List.of(Failure.fail(aggFn, "aggregation function [" + aggFn.nodeName().toUpperCase() + "] cannot be approximated"))
194+
);
195+
}
196+
return aggFn;
197+
});
167198
} else if (ROW_PRESERVING_COMMANDS.contains(plan.getClass()) == false) {
168199
hasFilters.set(true);
169200
}
@@ -316,7 +347,7 @@ private LogicalPlan approximatePlan(double sampleProbability) {
316347

317348
Eval addBucketId = new Eval(Source.EMPTY, aggregate.child(), List.of(bucketIdField));
318349
List<NamedExpression> aggregates = new ArrayList<>();
319-
Expression allBucketsNonNull = Literal.TRUE;
350+
Expression allBucketsNonEmpty = Literal.TRUE;
320351
for (NamedExpression aggOrKey : aggregate.aggregates()) {
321352
if ((aggOrKey instanceof Alias alias && alias.child() instanceof AggregateFunction) == false) {
322353
// This is a grouping key, not an aggregate function.
@@ -325,8 +356,10 @@ private LogicalPlan approximatePlan(double sampleProbability) {
325356
}
326357
Alias aggAlias = (Alias) aggOrKey;
327358
AggregateFunction agg = (AggregateFunction) aggAlias.child();
359+
boolean isMultiValued = SUPPORTED_MULTI_VALUED_AGGS.contains(agg.getClass());
360+
int bucketCount = isMultiValued ? 0 : BUCKET_COUNT;
328361
List<Alias> bucketedAggs = new ArrayList<>();
329-
for (int bucketId = -1; bucketId < BUCKET_COUNT; bucketId++) {
362+
for (int bucketId = -1; bucketId < bucketCount; bucketId++) {
330363
AggregateFunction bucketedAgg = agg.withFilter(
331364
new MvContains(Source.EMPTY, bucketIdField.toAttribute(), Literal.integer(Source.EMPTY, bucketId)));
332365
Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
@@ -345,12 +378,17 @@ private LogicalPlan approximatePlan(double sampleProbability) {
345378
if (bucketId >= 0) {
346379
bucketedAggs.add(correctedAggAlias);
347380
}
348-
allBucketsNonNull = new And(Source.EMPTY, allBucketsNonNull, new IsNotNull(Source.EMPTY, correctedAggAlias.toAttribute()));
381+
allBucketsNonEmpty = new And(Source.EMPTY, allBucketsNonEmpty,
382+
agg instanceof Count
383+
? new NotEquals(Source.EMPTY, correctedAggAlias.toAttribute(), Literal.integer(Source.EMPTY, 0))
384+
: new IsNotNull(Source.EMPTY, correctedAggAlias.toAttribute()));
385+
}
386+
if (isMultiValued == false) {
387+
variablesWithConfidenceInterval.put(aggOrKey.id(), bucketedAggs);
349388
}
350-
variablesWithConfidenceInterval.put(aggOrKey.id(), bucketedAggs);
351389
}
352390
plan = aggregate.with(addBucketId, aggregate.groupings(), aggregates);
353-
plan = new Filter(Source.EMPTY, plan, allBucketsNonNull);
391+
plan = new Filter(Source.EMPTY, plan, allBucketsNonEmpty);
354392

355393
} else if (encounteredStats.get()) {
356394
switch (plan) {

0 commit comments

Comments
 (0)