2222import org .elasticsearch .xpack .esql .core .tree .Source ;
2323import org .elasticsearch .xpack .esql .core .util .Holder ;
2424import org .elasticsearch .xpack .esql .expression .function .aggregate .AggregateFunction ;
25+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Avg ;
2526import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
27+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Median ;
28+ import org .elasticsearch .xpack .esql .expression .function .aggregate .MedianAbsoluteDeviation ;
29+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Percentile ;
30+ import org .elasticsearch .xpack .esql .expression .function .aggregate .StdDev ;
31+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Sum ;
32+ import org .elasticsearch .xpack .esql .expression .function .aggregate .WeightedAvg ;
2633import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .ConfidenceInterval ;
2734import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvAppend ;
2835import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .MvContains ;
2936import org .elasticsearch .xpack .esql .expression .function .scalar .random .Random ;
3037import org .elasticsearch .xpack .esql .expression .predicate .logical .And ;
3138import org .elasticsearch .xpack .esql .expression .predicate .nulls .IsNotNull ;
39+ import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .NotEquals ;
3240import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
3341import org .elasticsearch .xpack .esql .plan .logical .ChangePoint ;
3442import org .elasticsearch .xpack .esql .plan .logical .Dissect ;
5260import java .util .ArrayList ;
5361import java .util .HashMap ;
5462import java .util .List ;
63+ import java .util .Locale ;
5564import java .util .Map ;
5665import java .util .Set ;
5766import java .util .stream .Collectors ;
@@ -93,7 +102,6 @@ public interface LogicalPlanRunner {
93102 void run (LogicalPlan plan , ActionListener <Result > listener );
94103 }
95104
96-
97105 /**
98106 * These commands preserve all rows, making it easy to predict the number of output rows.
99107 */
@@ -112,10 +120,25 @@ public interface LogicalPlanRunner {
112120 Rename .class
113121 );
114122
123+ private static final Set <Class <? extends AggregateFunction >> SUPPORTED_SINGLE_VALUED_AGGS = Set .of (
124+ Avg .class ,
125+ Count .class ,
126+ Median .class ,
127+ MedianAbsoluteDeviation .class ,
128+ Percentile .class ,
129+ StdDev .class ,
130+ Sum .class ,
131+ WeightedAvg .class
132+ );
133+
134+ private static final Set <Class <? extends AggregateFunction >> SUPPORTED_MULTI_VALUED_AGGS = Set .of (
135+ org .elasticsearch .xpack .esql .expression .function .aggregate .Sample .class
136+ );
137+
115138 // TODO: find a good default value, or alternative ways of setting it
116139 private static final int SAMPLE_ROW_COUNT = 100000 ;
117140
118- private static final int BUCKET_COUNT = 3 ;
141+ private static final int BUCKET_COUNT = 16 ;
119142
120143 private static final Logger logger = LogManager .getLogger (Approximate .class );
121144
@@ -153,7 +176,7 @@ private boolean verifyPlan() {
153176 logicalPlan .forEachUp (plan -> {
154177 if (plan instanceof LeafPlan == false && plan instanceof UnaryPlan == false ) {
155178 throw new VerificationException (
156- List .of (Failure .fail (plan , "query with [" + plan .nodeName () + "] cannot be approximated" ))
179+ List .of (Failure .fail (plan , "query with [" + plan .nodeName (). toUpperCase ( Locale . ROOT ) + "] cannot be approximated" ))
157180 );
158181 }
159182 });
@@ -162,8 +185,16 @@ private boolean verifyPlan() {
162185 Holder <Boolean > hasFilters = new Holder <>(false );
163186 logicalPlan .transformUp (plan -> {
164187 if (encounteredStats .get () == false ) {
165- if (plan instanceof Aggregate ) {
188+ if (plan instanceof Aggregate aggregate ) {
166189 encounteredStats .set (true );
190+ plan .transformExpressionsOnly (AggregateFunction .class , aggFn -> {
191+ if (SUPPORTED_SINGLE_VALUED_AGGS .contains (aggFn .getClass ()) == false && SUPPORTED_MULTI_VALUED_AGGS .contains (aggFn .getClass ()) == false ) {
192+ throw new VerificationException (
193+ List .of (Failure .fail (aggFn , "aggregation function [" + aggFn .nodeName ().toUpperCase (Locale .ROOT ) + "] cannot be approximated" ))
194+ );
195+ }
196+ return aggFn ;
197+ });
167198 } else if (ROW_PRESERVING_COMMANDS .contains (plan .getClass ()) == false ) {
168199 hasFilters .set (true );
169200 }
@@ -316,7 +347,7 @@ private LogicalPlan approximatePlan(double sampleProbability) {
316347
317348 Eval addBucketId = new Eval (Source .EMPTY , aggregate .child (), List .of (bucketIdField ));
318349 List <NamedExpression > aggregates = new ArrayList <>();
319- Expression allBucketsNonNull = Literal .TRUE ;
350+ Expression allBucketsNonEmpty = Literal .TRUE ;
320351 for (NamedExpression aggOrKey : aggregate .aggregates ()) {
321352 if ((aggOrKey instanceof Alias alias && alias .child () instanceof AggregateFunction ) == false ) {
322353 // This is a grouping key, not an aggregate function.
@@ -325,8 +356,10 @@ private LogicalPlan approximatePlan(double sampleProbability) {
325356 }
326357 Alias aggAlias = (Alias ) aggOrKey ;
327358 AggregateFunction agg = (AggregateFunction ) aggAlias .child ();
359+ boolean isMultiValued = SUPPORTED_MULTI_VALUED_AGGS .contains (agg .getClass ());
360+ int bucketCount = isMultiValued ? 0 : BUCKET_COUNT ;
328361 List <Alias > bucketedAggs = new ArrayList <>();
329- for (int bucketId = -1 ; bucketId < BUCKET_COUNT ; bucketId ++) {
362+ for (int bucketId = -1 ; bucketId < bucketCount ; bucketId ++) {
330363 AggregateFunction bucketedAgg = agg .withFilter (
331364 new MvContains (Source .EMPTY , bucketIdField .toAttribute (), Literal .integer (Source .EMPTY , bucketId )));
332365 Expression correctedAgg = bucketedAgg instanceof NeedsSampleCorrection nsc
@@ -345,12 +378,17 @@ private LogicalPlan approximatePlan(double sampleProbability) {
345378 if (bucketId >= 0 ) {
346379 bucketedAggs .add (correctedAggAlias );
347380 }
348- allBucketsNonNull = new And (Source .EMPTY , allBucketsNonNull , new IsNotNull (Source .EMPTY , correctedAggAlias .toAttribute ()));
381+ allBucketsNonEmpty = new And (Source .EMPTY , allBucketsNonEmpty ,
382+ agg instanceof Count
383+ ? new NotEquals (Source .EMPTY , correctedAggAlias .toAttribute (), Literal .integer (Source .EMPTY , 0 ))
384+ : new IsNotNull (Source .EMPTY , correctedAggAlias .toAttribute ()));
385+ }
386+ if (isMultiValued == false ) {
387+ variablesWithConfidenceInterval .put (aggOrKey .id (), bucketedAggs );
349388 }
350- variablesWithConfidenceInterval .put (aggOrKey .id (), bucketedAggs );
351389 }
352390 plan = aggregate .with (addBucketId , aggregate .groupings (), aggregates );
353- plan = new Filter (Source .EMPTY , plan , allBucketsNonNull );
391+ plan = new Filter (Source .EMPTY , plan , allBucketsNonEmpty );
354392
355393 } else if (encounteredStats .get ()) {
356394 switch (plan ) {
0 commit comments