Skip to content

Commit f8a5b34

Browse files
committed
Return exact result when counting all rows
1 parent bd0f61d commit f8a5b34

File tree

2 files changed

+97
-25
lines changed

2 files changed

+97
-25
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
2222
import org.elasticsearch.xpack.esql.core.tree.Source;
2323
import org.elasticsearch.xpack.esql.core.util.Holder;
24+
import org.elasticsearch.xpack.esql.core.util.StringUtils;
2425
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
2526
import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg;
2627
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
@@ -122,6 +123,8 @@
122123
*/
123124
public class Approximate {
124125

126+
public record QueryProperties(boolean hasNonCountAllAgg, boolean preservesRows) {}
127+
125128
public interface LogicalPlanRunner {
126129
void run(LogicalPlan plan, ActionListener<Result> listener);
127130
}
@@ -223,21 +226,31 @@ public interface LogicalPlanRunner {
223226

224227
private static final Logger logger = LogManager.getLogger(Approximate.class);
225228

229+
private static final AggregateFunction COUNT_ALL_ROWS = new Count(Source.EMPTY, Literal.keyword(Source.EMPTY, StringUtils.WILDCARD));
230+
226231
private final LogicalPlan logicalPlan;
227-
private final boolean preservesRows;
232+
private final QueryProperties queryProperties;
228233
private final LogicalPlanRunner runner;
229234

235+
private long sourceRowCount;
236+
230237
public Approximate(LogicalPlan logicalPlan, LogicalPlanRunner logicalPlanRunner) {
231238
this.logicalPlan = logicalPlan;
232-
this.preservesRows = verifyPlan(logicalPlan);
239+
this.queryProperties = verifyPlan(logicalPlan);
233240
this.runner = logicalPlanRunner;
234241
}
235242

236243
/**
237244
* Computes approximate results for the logical plan.
238245
*/
239246
public void approximate(ActionListener<Result> listener) {
240-
runner.run(sourceCountPlan(), sourceCountListener(runner, listener));
247+
if (queryProperties.hasNonCountAllAgg || queryProperties.preservesRows == false) {
248+
runner.run(sourceCountPlan(), sourceCountListener(listener));
249+
} else {
250+
// Counting all rows is fast for queries that preserve all rows, as it's returned from
251+
// Lucene's metadata. Approximation would only slow things down in this case.
252+
runner.run(logicalPlan, listener);
253+
}
241254
}
242255

243256
/**
@@ -246,7 +259,7 @@ public void approximate(ActionListener<Result> listener) {
246259
* @return whether the part of the query until the STATS command preserves all rows
247260
* @throws VerificationException if the plan is not suitable for approximation
248261
*/
249-
public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationException {
262+
public static QueryProperties verifyPlan(LogicalPlan logicalPlan) throws VerificationException {
250263
if (logicalPlan.preOptimized() == false) {
251264
throw new IllegalStateException("Expected pre-optimized plan");
252265
}
@@ -267,10 +280,15 @@ public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationExc
267280
});
268281

269282
Holder<Boolean> encounteredStats = new Holder<>(false);
283+
Holder<Boolean> hasNonCountAllAgg = new Holder<>(false);
270284
Holder<Boolean> preservesRows = new Holder<>(true);
285+
271286
logicalPlan.transformUp(plan -> {
272287
if (encounteredStats.get() == false) {
273-
if (plan instanceof Aggregate) {
288+
if (plan instanceof Aggregate aggregate) {
289+
if (aggregate.groupings().isEmpty() == false) {
290+
hasNonCountAllAgg.set(true);
291+
}
274292
// Verify that the aggregate functions are supported.
275293
encounteredStats.set(true);
276294
plan.transformExpressionsOnly(AggregateFunction.class, aggFn -> {
@@ -286,6 +304,9 @@ public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationExc
286304
)
287305
);
288306
}
307+
if (aggFn.equals(COUNT_ALL_ROWS) == false) {
308+
hasNonCountAllAgg.set(true);
309+
}
289310
return aggFn;
290311
});
291312
} else if (plan instanceof LeafPlan == false && ROW_PRESERVING_COMMANDS.contains(plan.getClass()) == false) {
@@ -301,7 +322,7 @@ public static boolean verifyPlan(LogicalPlan logicalPlan) throws VerificationExc
301322
return plan;
302323
});
303324

304-
return preservesRows.get();
325+
return new QueryProperties(hasNonCountAllAgg.get(), preservesRows.get());
305326
}
306327

307328
/**
@@ -315,12 +336,7 @@ private LogicalPlan sourceCountPlan() {
315336
LogicalPlan sourceCountPlan = logicalPlan.transformUp(plan -> {
316337
if (plan instanceof LeafPlan) {
317338
// Append the leaf plan by a STATS COUNT(*).
318-
plan = new Aggregate(
319-
Source.EMPTY,
320-
plan,
321-
List.of(),
322-
List.of(new Alias(Source.EMPTY, "$count", new Count(Source.EMPTY, Literal.keyword(Source.EMPTY, "*"))))
323-
);
339+
plan = new Aggregate(Source.EMPTY, plan, List.of(), List.of(new Alias(Source.EMPTY, "$count", COUNT_ALL_ROWS)));
324340
} else {
325341
// Strip everything after the leaf.
326342
plan = plan.children().getFirst();
@@ -337,15 +353,15 @@ private LogicalPlan sourceCountPlan() {
337353
* {@link Approximate#approximatePlan} or {@link Approximate#countPlan}
338354
* depending on whether the original query preserves all rows or not.
339355
*/
340-
private ActionListener<Result> sourceCountListener(LogicalPlanRunner runner, ActionListener<Result> listener) {
356+
private ActionListener<Result> sourceCountListener(ActionListener<Result> listener) {
341357
return listener.delegateFailureAndWrap((countListener, countResult) -> {
342-
long rowCount = rowCount(countResult);
343-
logger.debug("sourceCountPlan result: {} rows", rowCount);
344-
double sampleProbability = rowCount <= SAMPLE_ROW_COUNT ? 1.0 : (double) SAMPLE_ROW_COUNT / rowCount;
345-
if (preservesRows || sampleProbability == 1.0) {
358+
sourceRowCount = rowCount(countResult);
359+
logger.debug("sourceCountPlan result: {} rows", sourceRowCount);
360+
double sampleProbability = sourceRowCount <= SAMPLE_ROW_COUNT ? 1.0 : (double) SAMPLE_ROW_COUNT / sourceRowCount;
361+
if (queryProperties.preservesRows || sampleProbability == 1.0) {
346362
runner.run(approximatePlan(sampleProbability), listener);
347363
} else {
348-
runner.run(countPlan(sampleProbability), countListener(runner, sampleProbability, listener));
364+
runner.run(countPlan(sampleProbability), countListener(sampleProbability, listener));
349365
}
350366
});
351367
}
@@ -372,7 +388,7 @@ private LogicalPlan countPlan(double sampleProbability) {
372388
Source.EMPTY,
373389
aggregate.child(),
374390
List.of(),
375-
List.of(new Alias(Source.EMPTY, "$count", new Count(Source.EMPTY, Literal.keyword(Source.EMPTY, "*"))))
391+
List.of(new Alias(Source.EMPTY, "$count", COUNT_ALL_ROWS))
376392
);
377393
}
378394
} else {
@@ -392,13 +408,13 @@ private LogicalPlan countPlan(double sampleProbability) {
392408
* {@link Approximate#countPlan} depending on whether the current count is
393409
* sufficient.
394410
*/
395-
private ActionListener<Result> countListener(LogicalPlanRunner runner, double sampleProbability, ActionListener<Result> listener) {
411+
private ActionListener<Result> countListener(double sampleProbability, ActionListener<Result> listener) {
396412
return listener.delegateFailureAndWrap((countListener, countResult) -> {
397413
long rowCount = rowCount(countResult);
398414
logger.debug("countPlan result (p={}): {} rows", sampleProbability, rowCount);
399415
double newSampleProbability = sampleProbability * SAMPLE_ROW_COUNT / Math.max(1, rowCount);
400416
if (rowCount <= SAMPLE_ROW_COUNT / 2 && newSampleProbability < 1.0) {
401-
runner.run(countPlan(newSampleProbability), countListener(runner, newSampleProbability, listener));
417+
runner.run(countPlan(newSampleProbability), countListener(newSampleProbability, listener));
402418
} else {
403419
runner.run(approximatePlan(newSampleProbability), listener);
404420
}
@@ -508,9 +524,17 @@ private LogicalPlan approximatePlan(double sampleProbability) {
508524
// Replace the original aggregation by a sample-corrected one.
509525
Alias aggAlias = (Alias) aggOrKey;
510526
AggregateFunction aggFn = (AggregateFunction) aggAlias.child();
527+
528+
if (aggFn.equals(COUNT_ALL_ROWS) && aggregate.groupings().isEmpty() && queryProperties.preservesRows) {
529+
// If the query is preserving all rows, and the aggregation function is
530+
// counting all rows, we know the exact result without sampling.
531+
aggregates.add(aggAlias.replaceChild(Literal.fromLong(Source.EMPTY, sourceRowCount)));
532+
continue;
533+
}
534+
511535
aggregates.add(aggAlias.replaceChild(correctForSampling(aggFn, sampleProbability)));
512536

513-
if (SUPPORTED_MULTIVALUED_AGGS.contains(aggFn.getClass()) == false) {
537+
if (SUPPORTED_SINGLE_VALUED_AGGS.contains(aggFn.getClass())) {
514538
// For the supported single-valued aggregations, add buckets with sampled
515539
// values, that will be used to compute a confidence interval.
516540
// For multivalued aggregations, confidence intervals do not make sense.

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximate/ApproximateTests.java

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import static org.hamcrest.CoreMatchers.allOf;
4949
import static org.hamcrest.CoreMatchers.not;
5050
import static org.hamcrest.Matchers.equalTo;
51+
import static org.hamcrest.Matchers.greaterThan;
5152
import static org.hamcrest.Matchers.hasSize;
5253
import static org.mockito.Mockito.mock;
5354

@@ -234,23 +235,64 @@ public void testCountPlan_smallDataBeforeFiltering() throws Exception {
234235
assertThat(runner.invocations.get(1), allOf(hasFilter("emp_no"), not(hasSample())));
235236
}
236237

237-
public void testApproximatePlan_createsConfidenceInterval() throws Exception {
238+
public void testApproximate_countAllRows() throws Exception {
238239
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
239-
Approximate approximate = createApproximate("FROM test | STATS SUM(emp_no)", runner);
240+
Approximate approximate = createApproximate("FROM test | STATS COUNT(*)", runner);
241+
approximate.approximate(TestRunner.resultCloser);
242+
assertThat(runner.invocations, hasSize(1));
243+
}
244+
245+
public void testApproximate_countAllRows_withFiltering() throws Exception {
246+
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
247+
Approximate approximate = createApproximate("FROM test | WHERE emp_no > 10 | STATS COUNT(*)", runner);
248+
approximate.approximate(TestRunner.resultCloser);
249+
assertThat(runner.invocations, hasSize(greaterThan(1)));
250+
}
251+
252+
public void testApproximate_countAllRows_withGrouping() throws Exception {
253+
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
254+
Approximate approximate = createApproximate("FROM test | STATS COUNT(*) BY emp_no", runner);
255+
approximate.approximate(TestRunner.resultCloser);
256+
assertThat(runner.invocations, hasSize(greaterThan(1)));
257+
}
258+
259+
public void testApproximatePlan_createsConfidenceInterval_withoutGrouping() throws Exception {
260+
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
261+
Approximate approximate = createApproximate("FROM test | STATS COUNT(), SUM(emp_no)", runner);
262+
approximate.approximate(TestRunner.resultCloser);
263+
// One pass is needed to get the number of rows, and approximation is executed immediately
264+
// after that with the correct sample probability.
265+
assertThat(runner.invocations, hasSize(2));
266+
267+
LogicalPlan approximatePlan = runner.invocations.get(1);
268+
assertThat(approximatePlan, hasSample(1e-4));
269+
// Counting all rows is exact, so no confidence interval is output.
270+
assertThat(approximatePlan, not(hasEval("CONFIDENCE_INTERVAL(COUNT())")));
271+
assertThat(approximatePlan, not(hasEval("RELIABLE(COUNT())")));
272+
assertThat(approximatePlan, hasEval("CONFIDENCE_INTERVAL(SUM(emp_no))"));
273+
assertThat(approximatePlan, hasEval("RELIABLE(SUM(emp_no))"));
274+
}
275+
276+
public void testApproximatePlan_createsConfidenceInterval_withGrouping() throws Exception {
277+
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
278+
Approximate approximate = createApproximate("FROM test | STATS COUNT(), SUM(emp_no) BY emp_no", runner);
240279
approximate.approximate(TestRunner.resultCloser);
241280
// One pass is needed to get the number of rows, and approximation is executed immediately
242281
// after that with the correct sample probability.
243282
assertThat(runner.invocations, hasSize(2));
244283

245284
LogicalPlan approximatePlan = runner.invocations.get(1);
246285
assertThat(approximatePlan, hasSample(1e-4));
286+
assertThat(approximatePlan, hasEval("CONFIDENCE_INTERVAL(COUNT())"));
287+
assertThat(approximatePlan, hasEval("RELIABLE(COUNT())"));
247288
assertThat(approximatePlan, hasEval("CONFIDENCE_INTERVAL(SUM(emp_no))"));
289+
assertThat(approximatePlan, hasEval("RELIABLE(SUM(emp_no))"));
248290
}
249291

250292
public void testApproximatePlan_dependentConfidenceIntervals() throws Exception {
251293
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
252294
Approximate approximate = createApproximate(
253-
"FROM test | STATS x=COUNT() | EVAL a=x*x, b=7, c=TO_STRING(x), d=MV_APPEND(x, 1::LONG), e=a+POW(b, 2)",
295+
"FROM test | STATS x=SUM(emp_no) | EVAL a=x*x, b=7, c=TO_STRING(x), d=MV_APPEND(x, 1::LONG), e=a+POW(b, 2)",
254296
runner
255297
);
256298
approximate.approximate(TestRunner.resultCloser);
@@ -261,11 +303,17 @@ public void testApproximatePlan_dependentConfidenceIntervals() throws Exception
261303
LogicalPlan approximatePlan = runner.invocations.get(1);
262304
assertThat(approximatePlan, hasPlan(Sample.class, s -> Foldables.literalValueOf(s.probability()).equals(1e-4)));
263305
assertThat(approximatePlan, hasEval("CONFIDENCE_INTERVAL(x)"));
306+
assertThat(approximatePlan, hasEval("RELIABLE(x)"));
264307
assertThat(approximatePlan, hasEval("CONFIDENCE_INTERVAL(a)"));
308+
assertThat(approximatePlan, hasEval("RELIABLE(a)"));
265309
assertThat(approximatePlan, not(hasEval("CONFIDENCE_INTERVAL(b)")));
310+
assertThat(approximatePlan, not(hasEval("RELIABLE(b)")));
266311
assertThat(approximatePlan, not(hasEval("CONFIDENCE_INTERVAL(c)")));
312+
assertThat(approximatePlan, not(hasEval("RELIABLE(c)")));
267313
assertThat(approximatePlan, not(hasEval("CONFIDENCE_INTERVAL(d)")));
314+
assertThat(approximatePlan, not(hasEval("RELIABLE(d)")));
268315
assertThat(approximatePlan, hasEval("CONFIDENCE_INTERVAL(e)"));
316+
assertThat(approximatePlan, hasEval("RELIABLE(e)"));
269317
}
270318

271319
private Matcher<? super LogicalPlan> hasFilter(String field) {

0 commit comments

Comments
 (0)