2121import org .elasticsearch .xpack .esql .core .expression .NamedExpression ;
2222import org .elasticsearch .xpack .esql .core .tree .Source ;
2323import org .elasticsearch .xpack .esql .core .util .Holder ;
24+ import org .elasticsearch .xpack .esql .core .util .StringUtils ;
2425import org .elasticsearch .xpack .esql .expression .function .aggregate .AggregateFunction ;
2526import org .elasticsearch .xpack .esql .expression .function .aggregate .Avg ;
2627import org .elasticsearch .xpack .esql .expression .function .aggregate .Count ;
122123 */
123124public class Approximate {
124125
/**
 * Properties of a query that decide how (and whether) it is approximated.
 *
 * @param hasNonCountAllAgg whether the STATS command has groupings or contains
 *                          any aggregate function other than a count of all rows
 * @param preservesRows     whether the part of the query until the STATS command
 *                          preserves all rows
 */
public record QueryProperties(boolean hasNonCountAllAgg, boolean preservesRows) {}
127+
125128 public interface LogicalPlanRunner {
126129 void run (LogicalPlan plan , ActionListener <Result > listener );
127130 }
@@ -223,21 +226,31 @@ public interface LogicalPlanRunner {
223226
224227 private static final Logger logger = LogManager .getLogger (Approximate .class );
225228
229+ private static final AggregateFunction COUNT_ALL_ROWS = new Count (Source .EMPTY , Literal .keyword (Source .EMPTY , StringUtils .WILDCARD ));
230+
226231 private final LogicalPlan logicalPlan ;
227- private final boolean preservesRows ;
232+ private final QueryProperties queryProperties ;
228233 private final LogicalPlanRunner runner ;
229234
235+ private long sourceRowCount ;
236+
/**
 * Creates an approximation for the given logical plan.
 *
 * @param logicalPlan       the plan to approximate; it is verified on construction
 * @param logicalPlanRunner used to execute the plans needed for the approximation
 * @throws VerificationException if the plan is not suitable for approximation
 */
public Approximate(LogicalPlan logicalPlan, LogicalPlanRunner logicalPlanRunner) {
    this.runner = logicalPlanRunner;
    this.logicalPlan = logicalPlan;
    // Verification both rejects unsupported plans and yields the properties
    // that steer the approximation strategy later on.
    this.queryProperties = verifyPlan(logicalPlan);
}
235242
236243 /**
237244 * Computes approximate results for the logical plan.
238245 */
239246 public void approximate (ActionListener <Result > listener ) {
240- runner .run (sourceCountPlan (), sourceCountListener (runner , listener ));
247+ if (queryProperties .hasNonCountAllAgg || queryProperties .preservesRows == false ) {
248+ runner .run (sourceCountPlan (), sourceCountListener (listener ));
249+ } else {
250+ // Counting all rows is fast for queries that preserve all rows, as it's returned from
251+ // Lucene's metadata. Approximation would only slow things down in this case.
252+ runner .run (logicalPlan , listener );
253+ }
241254 }
242255
243256 /**
@@ -246,7 +259,7 @@ public void approximate(ActionListener<Result> listener) {
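     * @return the {@link QueryProperties} of the query: whether it has groupings or an aggregate function other
     *         than a count of all rows, and whether the part of the query until the STATS command preserves all rows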
247260 * @throws VerificationException if the plan is not suitable for approximation
248261 */
249- public static boolean verifyPlan (LogicalPlan logicalPlan ) throws VerificationException {
262+ public static QueryProperties verifyPlan (LogicalPlan logicalPlan ) throws VerificationException {
250263 if (logicalPlan .preOptimized () == false ) {
251264 throw new IllegalStateException ("Expected pre-optimized plan" );
252265 }
@@ -267,10 +280,15 @@ public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationExc
267280 });
268281
269282 Holder <Boolean > encounteredStats = new Holder <>(false );
283+ Holder <Boolean > hasNonCountAllAgg = new Holder <>(false );
270284 Holder <Boolean > preservesRows = new Holder <>(true );
285+
271286 logicalPlan .transformUp (plan -> {
272287 if (encounteredStats .get () == false ) {
273- if (plan instanceof Aggregate ) {
288+ if (plan instanceof Aggregate aggregate ) {
289+ if (aggregate .groupings ().isEmpty () == false ) {
290+ hasNonCountAllAgg .set (true );
291+ }
274292 // Verify that the aggregate functions are supported.
275293 encounteredStats .set (true );
276294 plan .transformExpressionsOnly (AggregateFunction .class , aggFn -> {
@@ -286,6 +304,9 @@ public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationExc
286304 )
287305 );
288306 }
307+ if (aggFn .equals (COUNT_ALL_ROWS ) == false ) {
308+ hasNonCountAllAgg .set (true );
309+ }
289310 return aggFn ;
290311 });
291312 } else if (plan instanceof LeafPlan == false && ROW_PRESERVING_COMMANDS .contains (plan .getClass ()) == false ) {
@@ -301,7 +322,7 @@ public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationExc
301322 return plan ;
302323 });
303324
304- return preservesRows .get ();
325+ return new QueryProperties ( hasNonCountAllAgg . get (), preservesRows .get () );
305326 }
306327
307328 /**
@@ -315,12 +336,7 @@ private LogicalPlan sourceCountPlan() {
315336 LogicalPlan sourceCountPlan = logicalPlan .transformUp (plan -> {
316337 if (plan instanceof LeafPlan ) {
317338 // Append the leaf plan by a STATS COUNT(*).
318- plan = new Aggregate (
319- Source .EMPTY ,
320- plan ,
321- List .of (),
322- List .of (new Alias (Source .EMPTY , "$count" , new Count (Source .EMPTY , Literal .keyword (Source .EMPTY , "*" ))))
323- );
339+ plan = new Aggregate (Source .EMPTY , plan , List .of (), List .of (new Alias (Source .EMPTY , "$count" , COUNT_ALL_ROWS )));
324340 } else {
325341 // Strip everything after the leaf.
326342 plan = plan .children ().getFirst ();
@@ -337,15 +353,15 @@ private LogicalPlan sourceCountPlan() {
337353 * {@link Approximate#approximatePlan} or {@link Approximate#countPlan}
338354 * depending on whether the original query preserves all rows or not.
339355 */
340- private ActionListener <Result > sourceCountListener (LogicalPlanRunner runner , ActionListener <Result > listener ) {
356+ private ActionListener <Result > sourceCountListener (ActionListener <Result > listener ) {
341357 return listener .delegateFailureAndWrap ((countListener , countResult ) -> {
342- long rowCount = rowCount (countResult );
343- logger .debug ("sourceCountPlan result: {} rows" , rowCount );
344- double sampleProbability = rowCount <= SAMPLE_ROW_COUNT ? 1.0 : (double ) SAMPLE_ROW_COUNT / rowCount ;
345- if (preservesRows || sampleProbability == 1.0 ) {
358+ sourceRowCount = rowCount (countResult );
359+ logger .debug ("sourceCountPlan result: {} rows" , sourceRowCount );
360+ double sampleProbability = sourceRowCount <= SAMPLE_ROW_COUNT ? 1.0 : (double ) SAMPLE_ROW_COUNT / sourceRowCount ;
361+ if (queryProperties . preservesRows || sampleProbability == 1.0 ) {
346362 runner .run (approximatePlan (sampleProbability ), listener );
347363 } else {
348- runner .run (countPlan (sampleProbability ), countListener (runner , sampleProbability , listener ));
364+ runner .run (countPlan (sampleProbability ), countListener (sampleProbability , listener ));
349365 }
350366 });
351367 }
@@ -372,7 +388,7 @@ private LogicalPlan countPlan(double sampleProbability) {
372388 Source .EMPTY ,
373389 aggregate .child (),
374390 List .of (),
375- List .of (new Alias (Source .EMPTY , "$count" , new Count ( Source . EMPTY , Literal . keyword ( Source . EMPTY , "*" )) ))
391+ List .of (new Alias (Source .EMPTY , "$count" , COUNT_ALL_ROWS ))
376392 );
377393 }
378394 } else {
@@ -392,13 +408,13 @@ private LogicalPlan countPlan(double sampleProbability) {
392408 * {@link Approximate#countPlan} depending on whether the current count is
393409 * sufficient.
394410 */
395- private ActionListener <Result > countListener (LogicalPlanRunner runner , double sampleProbability , ActionListener <Result > listener ) {
411+ private ActionListener <Result > countListener (double sampleProbability , ActionListener <Result > listener ) {
396412 return listener .delegateFailureAndWrap ((countListener , countResult ) -> {
397413 long rowCount = rowCount (countResult );
398414 logger .debug ("countPlan result (p={}): {} rows" , sampleProbability , rowCount );
399415 double newSampleProbability = sampleProbability * SAMPLE_ROW_COUNT / Math .max (1 , rowCount );
400416 if (rowCount <= SAMPLE_ROW_COUNT / 2 && newSampleProbability < 1.0 ) {
401- runner .run (countPlan (newSampleProbability ), countListener (runner , newSampleProbability , listener ));
417+ runner .run (countPlan (newSampleProbability ), countListener (newSampleProbability , listener ));
402418 } else {
403419 runner .run (approximatePlan (newSampleProbability ), listener );
404420 }
@@ -508,9 +524,17 @@ private LogicalPlan approximatePlan(double sampleProbability) {
508524 // Replace the original aggregation by a sample-corrected one.
509525 Alias aggAlias = (Alias ) aggOrKey ;
510526 AggregateFunction aggFn = (AggregateFunction ) aggAlias .child ();
527+
528+ if (aggFn .equals (COUNT_ALL_ROWS ) && aggregate .groupings ().isEmpty () && queryProperties .preservesRows ) {
529+ // If the query is preserving all rows, and the aggregation function is
530+ // counting all rows, we know the exact result without sampling.
531+ aggregates .add (aggAlias .replaceChild (Literal .fromLong (Source .EMPTY , sourceRowCount )));
532+ continue ;
533+ }
534+
511535 aggregates .add (aggAlias .replaceChild (correctForSampling (aggFn , sampleProbability )));
512536
513- if (SUPPORTED_MULTIVALUED_AGGS .contains (aggFn .getClass ()) == false ) {
537+ if (SUPPORTED_SINGLE_VALUED_AGGS .contains (aggFn .getClass ())) {
514538 // For the supported single-valued aggregations, add buckets with sampled
515539 // values, that will be used to compute a confidence interval.
516540 // For multivalued aggregations, confidence intervals do not make sense.
0 commit comments