1010import org .apache .logging .log4j .LogManager ;
1111import org .apache .logging .log4j .Logger ;
1212import org .elasticsearch .action .ActionListener ;
13+ import org .elasticsearch .common .util .set .Sets ;
1314import org .elasticsearch .compute .data .LongBlock ;
1415import org .elasticsearch .compute .data .Page ;
1516import org .elasticsearch .xpack .esql .VerificationException ;
132133 */
133134public class Approximate {
134135
135- public record QueryProperties (boolean preservesRows ) {}
136+ public record QueryProperties (boolean canDecreaseRowCount , boolean canIncreaseRowCount ) {}
136137
137138 /**
138139 * These processing commands are supported.
@@ -168,6 +169,7 @@ public record QueryProperties(boolean preservesRows) {}
168169 */
169170 private static final Set <Class <? extends LogicalPlan >> ROW_PRESERVING_COMMANDS = Set .of (
170171 ChangePoint .class ,
172+ Completion .class ,
171173 Dissect .class ,
172174 Drop .class ,
173175 Enrich .class ,
@@ -178,7 +180,25 @@ public record QueryProperties(boolean preservesRows) {}
178180 Keep .class ,
179181 OrderBy .class ,
180182 Project .class ,
181- Rename .class
183+ RegexExtract .class ,
184+ Rename .class ,
185+ Rerank .class
186+ );
187+
188+ /**
189+ * These commands never increase the number of all rows, making it easier to predict the number of output rows.
190+ */
191+ private static final Set <Class <? extends LogicalPlan >> ROW_NON_INCREASING_COMMANDS = Sets .union (
192+ Set .of (Filter .class , Limit .class , Sample .class , TopN .class ),
193+ ROW_PRESERVING_COMMANDS
194+ );
195+
196+ /**
197+ * These commands never decrease the number of all rows, making it easier to predict the number of output rows.
198+ */
199+ private static final Set <Class <? extends LogicalPlan >> ROW_NON_DECREASING_COMMANDS = Sets .union (
200+ Set .of (MvExpand .class ),
201+ ROW_PRESERVING_COMMANDS
182202 );
183203
184204 /**
@@ -290,7 +310,8 @@ public static QueryProperties verifyPlan(LogicalPlan logicalPlan) throws Verific
290310 });
291311
292312 Holder <Boolean > encounteredStats = new Holder <>(false );
293- Holder <Boolean > preservesRows = new Holder <>(true );
313+ Holder <Boolean > canIncreaseRowCount = new Holder <>(false );
314+ Holder <Boolean > canDecreaseRowCount = new Holder <>(false );
294315
295316 logicalPlan .transformUp (plan -> {
296317 if (encounteredStats .get () == false ) {
@@ -312,9 +333,13 @@ public static QueryProperties verifyPlan(LogicalPlan logicalPlan) throws Verific
312333 }
313334 return aggFn ;
314335 });
315- } else if (plan instanceof LeafPlan == false && ROW_PRESERVING_COMMANDS .contains (plan .getClass ()) == false ) {
316- // Keep track of whether the plan until the STATS preserves all rows.
317- preservesRows .set (false );
336+ } else if (plan instanceof LeafPlan == false ) {
337+ if (ROW_NON_DECREASING_COMMANDS .contains (plan .getClass ()) == false ) {
338+ canDecreaseRowCount .set (true );
339+ }
340+ if (ROW_NON_INCREASING_COMMANDS .contains (plan .getClass ()) == false ) {
341+ canIncreaseRowCount .set (true );
342+ }
318343 }
319344 } else {
320345 // Multiple STATS commands are not supported.
@@ -325,7 +350,7 @@ public static QueryProperties verifyPlan(LogicalPlan logicalPlan) throws Verific
325350 return plan ;
326351 });
327352
328- return new QueryProperties (preservesRows .get ());
353+ return new QueryProperties (canDecreaseRowCount . get (), canIncreaseRowCount .get ());
329354 }
330355
331356 /**
@@ -348,11 +373,12 @@ private ActionListener<Result> approximateListener(ActionListener<Result> listen
348373 return new ActionListener <>() {
349374 @ Override
350375 public void onResponse (Result result ) {
351- boolean esStatsQueryExecuted = result .executionInfo () != null && result .executionInfo ().clusterInfo .values ()
352- .stream ()
353- .noneMatch (
354- cluster -> cluster .getFailures ().stream ().anyMatch (e -> e .getCause () instanceof UnsupportedOperationException )
355- );
376+ boolean esStatsQueryExecuted = result .executionInfo () != null
377+ && result .executionInfo ().clusterInfo .values ()
378+ .stream ()
379+ .noneMatch (
380+ cluster -> cluster .getFailures ().stream ().anyMatch (e -> e .getCause () instanceof UnsupportedOperationException )
381+ );
356382 if (esStatsQueryExecuted ) {
357383 logger .debug ("not approximating stats query" );
358384 listener .onResponse (result );
@@ -406,9 +432,15 @@ private ActionListener<Result> sourceCountListener(ActionListener<Result> listen
406432 sourceRowCount = rowCount (countResult );
407433 logger .debug ("sourceCountPlan result: {} rows" , sourceRowCount );
408434 double sampleProbability = sourceRowCount <= SAMPLE_ROW_COUNT ? 1.0 : (double ) SAMPLE_ROW_COUNT / sourceRowCount ;
409- if (queryProperties .preservesRows ) {
435+ if (queryProperties .canIncreaseRowCount == false && sampleProbability == 1.0 ) {
436+ // If the query cannot increase the number of rows, and the sample probability is 1.0,
437+ // we can directly approximate without sampling.
438+ runner .run (toPhysicalPlan .apply (logicalPlan ), configuration , foldContext , listener );
439+ } else if (queryProperties .canIncreaseRowCount == false && queryProperties .canDecreaseRowCount == false ) {
440+ // If the query preserves all rows, we can directly approximate with the sample probability.
410441 runner .run (toPhysicalPlan .apply (approximatePlan (sampleProbability )), configuration , foldContext , listener );
411442 } else {
443+ // Otherwise, we need to sample the number of rows first to obtain a good sample probability.
412444 runner .run (
413445 toPhysicalPlan .apply (countPlan (sampleProbability )),
414446 configuration ,
@@ -585,7 +617,10 @@ private LogicalPlan approximatePlan(double sampleProbability) {
585617 Alias aggAlias = (Alias ) aggOrKey ;
586618 AggregateFunction aggFn = (AggregateFunction ) aggAlias .child ();
587619
588- if (aggFn .equals (COUNT_ALL_ROWS ) && aggregate .groupings ().isEmpty () && queryProperties .preservesRows ) {
620+ if (aggFn .equals (COUNT_ALL_ROWS )
621+ && aggregate .groupings ().isEmpty ()
622+ && queryProperties .canDecreaseRowCount == false
623+ && queryProperties .canIncreaseRowCount == false ) {
589624 // If the query is preserving all rows, and the aggregation function is
590625 // counting all rows, we know the exact result without sampling.
591626 aggregates .add (aggAlias .replaceChild (Literal .fromLong (Source .EMPTY , sourceRowCount )));
@@ -746,7 +781,9 @@ private LogicalPlan approximatePlan(double sampleProbability) {
746781 default -> throw new IllegalStateException ("unexpected data type [" + output .dataType () + "]" );
747782 };
748783 confidenceIntervalsAndReliable .add (
749- new Alias (Source .EMPTY , "CONFIDENCE_INTERVAL(" + output .name () + ")" ,
784+ new Alias (
785+ Source .EMPTY ,
786+ "CONFIDENCE_INTERVAL(" + output .name () + ")" ,
750787 new MvSlice (Source .EMPTY , confidenceInterval , Literal .integer (Source .EMPTY , 0 ), Literal .integer (Source .EMPTY , 1 ))
751788 )
752789 );
@@ -756,7 +793,12 @@ private LogicalPlan approximatePlan(double sampleProbability) {
756793 "RELIABLE(" + output .name () + ")" ,
757794 new GreaterThanOrEqual (
758795 Source .EMPTY ,
759- new MvSlice (Source .EMPTY , confidenceInterval , Literal .integer (Source .EMPTY , 2 ), Literal .integer (Source .EMPTY , 2 )),
796+ new MvSlice (
797+ Source .EMPTY ,
798+ confidenceInterval ,
799+ Literal .integer (Source .EMPTY , 2 ),
800+ Literal .integer (Source .EMPTY , 2 )
801+ ),
760802 Literal .fromDouble (Source .EMPTY , 0.5 )
761803 )
762804 )
0 commit comments