22
22
import org .elasticsearch .xpack .esql .core .tree .Source ;
23
23
import org .elasticsearch .xpack .esql .core .util .Holder ;
24
24
import org .elasticsearch .xpack .esql .expression .function .aggregate .AggregateFunction ;
25
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Avg ;
25
26
import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
27
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Median ;
28
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .MedianAbsoluteDeviation ;
29
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Percentile ;
30
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .StdDev ;
31
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Sum ;
32
+ import org .elasticsearch .xpack .esql .expression .function .aggregate .WeightedAvg ;
26
33
import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .ConfidenceInterval ;
27
34
import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvAppend ;
28
35
import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvContains ;
29
36
import org .elasticsearch .xpack .esql .expression .function .scalar .random .Random ;
30
37
import org .elasticsearch .xpack .esql .expression .predicate .logical .And ;
31
38
import org .elasticsearch .xpack .esql .expression .predicate .nulls .IsNotNull ;
39
+ import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .NotEquals ;
32
40
import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
33
41
import org .elasticsearch .xpack .esql .plan .logical .ChangePoint ;
34
42
import org .elasticsearch .xpack .esql .plan .logical .Dissect ;
52
60
import java .util .ArrayList ;
53
61
import java .util .HashMap ;
54
62
import java .util .List ;
63
+ import java .util .Locale ;
55
64
import java .util .Map ;
56
65
import java .util .Set ;
57
66
import java .util .stream .Collectors ;
@@ -93,7 +102,6 @@ public interface LogicalPlanRunner {
93
102
void run (LogicalPlan plan , ActionListener <Result > listener );
94
103
}
95
104
96
-
97
105
/**
98
106
* These commands preserve all rows, making it easy to predict the number of output rows.
99
107
*/
@@ -112,10 +120,25 @@ public interface LogicalPlanRunner {
112
120
Rename .class
113
121
);
114
122
123
+ private static final Set <Class <? extends AggregateFunction >> SUPPORTED_SINGLE_VALUED_AGGS = Set .of (
124
+ Avg .class ,
125
+ Count .class ,
126
+ Median .class ,
127
+ MedianAbsoluteDeviation .class ,
128
+ Percentile .class ,
129
+ StdDev .class ,
130
+ Sum .class ,
131
+ WeightedAvg .class
132
+ );
133
+
134
+ private static final Set <Class <? extends AggregateFunction >> SUPPORTED_MULTI_VALUED_AGGS = Set .of (
135
+ org .elasticsearch .xpack .esql .expression .function .aggregate .Sample .class
136
+ );
137
+
115
138
// TODO: find a good default value, or alternative ways of setting it
116
139
private static final int SAMPLE_ROW_COUNT = 100000 ;
117
140
118
- private static final int BUCKET_COUNT = 3 ;
141
+ private static final int BUCKET_COUNT = 16 ;
119
142
120
143
private static final Logger logger = LogManager .getLogger (Approximate .class );
121
144
@@ -153,7 +176,7 @@ private boolean verifyPlan() {
153
176
logicalPlan .forEachUp (plan -> {
154
177
if (plan instanceof LeafPlan == false && plan instanceof UnaryPlan == false ) {
155
178
throw new VerificationException (
156
- List .of (Failure .fail (plan , "query with [" + plan .nodeName () + "] cannot be approximated" ))
179
+ List .of (Failure .fail (plan , "query with [" + plan .nodeName (). toUpperCase ( Locale . ROOT ) + "] cannot be approximated" ))
157
180
);
158
181
}
159
182
});
@@ -162,8 +185,16 @@ private boolean verifyPlan() {
162
185
Holder <Boolean > hasFilters = new Holder <>(false );
163
186
logicalPlan .transformUp (plan -> {
164
187
if (encounteredStats .get () == false ) {
165
- if (plan instanceof Aggregate ) {
188
+ if (plan instanceof Aggregate aggregate ) {
166
189
encounteredStats .set (true );
190
+ plan .transformExpressionsOnly (AggregateFunction .class , aggFn -> {
191
+ if (SUPPORTED_SINGLE_VALUED_AGGS .contains (aggFn .getClass ()) == false && SUPPORTED_MULTI_VALUED_AGGS .contains (aggFn .getClass ()) == false ) {
192
+ throw new VerificationException (
193
+ List .of (Failure .fail (aggFn , "aggregation function [" + aggFn .nodeName ().toUpperCase (Locale .ROOT ) + "] cannot be approximated" ))
194
+ );
195
+ }
196
+ return aggFn ;
197
+ });
167
198
} else if (ROW_PRESERVING_COMMANDS .contains (plan .getClass ()) == false ) {
168
199
hasFilters .set (true );
169
200
}
@@ -316,7 +347,7 @@ private LogicalPlan approximatePlan(double sampleProbability) {
316
347
317
348
Eval addBucketId = new Eval (Source .EMPTY , aggregate .child (), List .of (bucketIdField ));
318
349
List <NamedExpression > aggregates = new ArrayList <>();
319
- Expression allBucketsNonNull = Literal .TRUE ;
350
+ Expression allBucketsNonEmpty = Literal .TRUE ;
320
351
for (NamedExpression aggOrKey : aggregate .aggregates ()) {
321
352
if ((aggOrKey instanceof Alias alias && alias .child () instanceof AggregateFunction ) == false ) {
322
353
// This is a grouping key, not an aggregate function.
@@ -325,8 +356,10 @@ private LogicalPlan approximatePlan(double sampleProbability) {
325
356
}
326
357
Alias aggAlias = (Alias ) aggOrKey ;
327
358
AggregateFunction agg = (AggregateFunction ) aggAlias .child ();
359
+ boolean isMultiValued = SUPPORTED_MULTI_VALUED_AGGS .contains (agg .getClass ());
360
+ int bucketCount = isMultiValued ? 0 : BUCKET_COUNT ;
328
361
List <Alias > bucketedAggs = new ArrayList <>();
329
- for (int bucketId = -1 ; bucketId < BUCKET_COUNT ; bucketId ++) {
362
+ for (int bucketId = -1 ; bucketId < bucketCount ; bucketId ++) {
330
363
AggregateFunction bucketedAgg = agg .withFilter (
331
364
new MvContains (Source .EMPTY , bucketIdField .toAttribute (), Literal .integer (Source .EMPTY , bucketId )));
332
365
Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
@@ -345,12 +378,17 @@ private LogicalPlan approximatePlan(double sampleProbability) {
345
378
if (bucketId >= 0 ) {
346
379
bucketedAggs .add (correctedAggAlias );
347
380
}
348
- allBucketsNonNull = new And (Source .EMPTY , allBucketsNonNull , new IsNotNull (Source .EMPTY , correctedAggAlias .toAttribute ()));
381
+ allBucketsNonEmpty = new And (Source .EMPTY , allBucketsNonEmpty ,
382
+ agg instanceof Count
383
+ ? new NotEquals (Source .EMPTY , correctedAggAlias .toAttribute (), Literal .integer (Source .EMPTY , 0 ))
384
+ : new IsNotNull (Source .EMPTY , correctedAggAlias .toAttribute ()));
385
+ }
386
+ if (isMultiValued == false ) {
387
+ variablesWithConfidenceInterval .put (aggOrKey .id (), bucketedAggs );
349
388
}
350
- variablesWithConfidenceInterval .put (aggOrKey .id (), bucketedAggs );
351
389
}
352
390
plan = aggregate .with (addBucketId , aggregate .groupings (), aggregates );
353
- plan = new Filter (Source .EMPTY , plan , allBucketsNonNull );
391
+ plan = new Filter (Source .EMPTY , plan , allBucketsNonEmpty );
354
392
355
393
} else if (encounteredStats .get ()) {
356
394
switch (plan ) {
0 commit comments