2828import org .elasticsearch .xpack .esql .core .expression .ReferenceAttribute ;
2929import org .elasticsearch .xpack .esql .core .tree .Source ;
3030import org .elasticsearch .xpack .esql .expression .Foldables ;
31+ import org .elasticsearch .xpack .esql .expression .function .aggregate .Sum ;
3132import org .elasticsearch .xpack .esql .inference .InferenceService ;
3233import org .elasticsearch .xpack .esql .optimizer .LogicalOptimizerContext ;
3334import org .elasticsearch .xpack .esql .optimizer .LogicalPlanOptimizer ;
3435import org .elasticsearch .xpack .esql .optimizer .LogicalPlanPreOptimizer ;
3536import org .elasticsearch .xpack .esql .optimizer .LogicalPreOptimizerContext ;
3637import org .elasticsearch .xpack .esql .parser .EsqlParser ;
3738import org .elasticsearch .xpack .esql .parser .QueryParams ;
39+ import org .elasticsearch .xpack .esql .plan .logical .Aggregate ;
3840import org .elasticsearch .xpack .esql .plan .logical .Eval ;
3941import org .elasticsearch .xpack .esql .plan .logical .Filter ;
4042import org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
@@ -233,9 +235,9 @@ public void testCountPlan_noData() throws Exception {
233235 // - one pass to get the total number of rows (which is zero)
234236 // - one pass to execute the original query
235237 assertThat (runner .invocations , hasSize (3 ));
236- assertThat (runner .invocations .get (0 ), allOf (not (hasSample ())));
237- assertThat (runner .invocations .get (1 ), allOf (not (hasSample ())));
238- assertThat (runner .invocations .get (2 ), allOf (not (hasSample ())));
238+ assertThat (runner .invocations .get (0 ), allOf (not (hasSample ()), hasSum () ));
239+ assertThat (runner .invocations .get (1 ), allOf (not (hasSample ()), not ( hasSum ()) ));
240+ assertThat (runner .invocations .get (2 ), allOf (not (hasSample ()), hasSum () ));
239241 }
240242
241243 public void testCountPlan_largeDataNoFilters () throws Exception {
@@ -247,9 +249,9 @@ public void testCountPlan_largeDataNoFilters() throws Exception {
247249 // - one pass to get the total number of rows (which determines the sample probability)
248250 // - one pass to approximate the query
249251 assertThat (runner .invocations , hasSize (3 ));
250- assertThat (runner .invocations .get (0 ), allOf (not (hasSample ())));
251- assertThat (runner .invocations .get (1 ), allOf (not (hasSample ())));
252- assertThat (runner .invocations .get (2 ), allOf (hasSample (1e-4 )));
252+ assertThat (runner .invocations .get (0 ), allOf (not (hasSample ()), hasSum () ));
253+ assertThat (runner .invocations .get (1 ), allOf (not (hasSample ()), not ( hasSum ()) ));
254+ assertThat (runner .invocations .get (2 ), allOf (hasSample (1e-4 ), hasSum () ));
253255 }
254256
255257 public void testCountPlan_smallDataNoFilters () throws Exception {
@@ -261,9 +263,9 @@ public void testCountPlan_smallDataNoFilters() throws Exception {
261263 // - one pass to get the total number of rows (which is small)
262264 // - one pass to execute the original query
263265 assertThat (runner .invocations , hasSize (3 ));
264- assertThat (runner .invocations .get (0 ), allOf (not (hasSample ())));
265- assertThat (runner .invocations .get (1 ), allOf (not (hasSample ())));
266- assertThat (runner .invocations .get (2 ), allOf (not (hasSample ())));
266+ assertThat (runner .invocations .get (0 ), allOf (not (hasSample ()), hasSum () ));
267+ assertThat (runner .invocations .get (1 ), allOf (not (hasSample ()), not ( hasSum ()) ));
268+ assertThat (runner .invocations .get (2 ), allOf (not (hasSample ()), hasSum () ));
267269 }
268270
269271 public void testCountPlan_largeDataAfterFiltering () throws Exception {
@@ -276,11 +278,11 @@ public void testCountPlan_largeDataAfterFiltering() throws Exception {
276278 // - two passes to get the number of filtered rows (which determines the sample probability)
277279 // - one pass to approximate the query
278280 assertThat (runner .invocations , hasSize (5 ));
279- assertThat (runner .invocations .get (0 ), allOf (hasFilter ("emp_no" ), not (hasSample ())));
280- assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ("emp_no" )), not (hasSample ())));
281- assertThat (runner .invocations .get (2 ), allOf (hasFilter ("emp_no" ), hasSample (1e-8 )));
282- assertThat (runner .invocations .get (3 ), allOf (hasFilter ("emp_no" ), hasSample (1e-5 )));
283- assertThat (runner .invocations .get (4 ), allOf (hasFilter ("emp_no" ), hasSample (1e-4 )));
281+ assertThat (runner .invocations .get (0 ), allOf (hasFilter ("emp_no" ), not (hasSample ()), hasSum () ));
282+ assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ("emp_no" )), not (hasSample ()), not ( hasSum ()) ));
283+ assertThat (runner .invocations .get (2 ), allOf (hasFilter ("emp_no" ), hasSample (1e-8 ), not ( hasSum ()) ));
284+ assertThat (runner .invocations .get (3 ), allOf (hasFilter ("emp_no" ), hasSample (1e-5 ), not ( hasSum ()) ));
285+ assertThat (runner .invocations .get (4 ), allOf (hasFilter ("emp_no" ), hasSample (1e-4 ), hasSum () ));
284286 }
285287
286288 public void testCountPlan_smallDataAfterFiltering () throws Exception {
@@ -293,14 +295,14 @@ public void testCountPlan_smallDataAfterFiltering() throws Exception {
293295 // - three passes to get the number of filtered rows (which is small)
294296 // - one pass to execute the original query
295297 assertThat (runner .invocations , hasSize (8 ));
296- assertThat (runner .invocations .get (0 ), allOf (hasFilter ("emp_no" ), not (hasSample ())));
297- assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ("emp_no" )), not (hasSample ())));
298- assertThat (runner .invocations .get (2 ), allOf (hasFilter ("emp_no" ), hasSample (1e-14 )));
299- assertThat (runner .invocations .get (3 ), allOf (hasFilter ("emp_no" ), hasSample (1e-10 )));
300- assertThat (runner .invocations .get (4 ), allOf (hasFilter ("emp_no" ), hasSample (1e-6 )));
301- assertThat (runner .invocations .get (5 ), allOf (hasFilter ("emp_no" ), hasSample (1e-2 )));
302- assertThat (runner .invocations .get (6 ), allOf (hasFilter ("emp_no" ), not (hasSample ())));
303- assertThat (runner .invocations .get (7 ), allOf (hasFilter ("emp_no" ), not (hasSample ())));
298+ assertThat (runner .invocations .get (0 ), allOf (hasFilter ("emp_no" ), not (hasSample ()), hasSum () ));
299+ assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ("emp_no" )), not (hasSample ()), not ( hasSum ()) ));
300+ assertThat (runner .invocations .get (2 ), allOf (hasFilter ("emp_no" ), hasSample (1e-14 ), not ( hasSum ()) ));
301+ assertThat (runner .invocations .get (3 ), allOf (hasFilter ("emp_no" ), hasSample (1e-10 ), not ( hasSum ()) ));
302+ assertThat (runner .invocations .get (4 ), allOf (hasFilter ("emp_no" ), hasSample (1e-6 ), not ( hasSum ()) ));
303+ assertThat (runner .invocations .get (5 ), allOf (hasFilter ("emp_no" ), hasSample (1e-2 ), not ( hasSum ()) ));
304+ assertThat (runner .invocations .get (6 ), allOf (hasFilter ("emp_no" ), not (hasSample ()), not ( hasSum ()) ));
305+ assertThat (runner .invocations .get (7 ), allOf (hasFilter ("emp_no" ), not (hasSample ()), hasSum () ));
304306 }
305307
306308 public void testCountPlan_smallDataBeforeFiltering () throws Exception {
@@ -312,9 +314,9 @@ public void testCountPlan_smallDataBeforeFiltering() throws Exception {
312314 // - one pass to get the total number of rows (which is small)
313315 // - one pass to execute the original query
314316 assertThat (runner .invocations , hasSize (3 ));
315- assertThat (runner .invocations .get (0 ), allOf (hasFilter ("gender" ), not (hasSample ())));
316- assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ("gender" )), not (hasSample ())));
317- assertThat (runner .invocations .get (2 ), allOf (hasFilter ("gender" ), not (hasSample ())));
317+ assertThat (runner .invocations .get (0 ), allOf (hasFilter ("gender" ), not (hasSample ()), hasSum () ));
318+ assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ("gender" )), not (hasSample ()), not ( hasSum ()) ));
319+ assertThat (runner .invocations .get (2 ), allOf (hasFilter ("gender" ), not (hasSample ()), hasSum () ));
318320 }
319321
320322 public void testCountPlan_smallDataAfterMvExpanding () throws Exception {
@@ -327,10 +329,10 @@ public void testCountPlan_smallDataAfterMvExpanding() throws Exception {
327329 // - one pass to get the number of expanded rows (which determines the sample probability)
328330 // - one pass to execute the original query
329331 assertThat (runner .invocations , hasSize (4 ));
330- assertThat (runner .invocations .get (0 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ())));
331- assertThat (runner .invocations .get (1 ), allOf (not (hasMvExpand ("emp_no" )), not (hasSample ())));
332- assertThat (runner .invocations .get (2 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ())));
333- assertThat (runner .invocations .get (3 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ())));
332+ assertThat (runner .invocations .get (0 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ()), hasSum () ));
333+ assertThat (runner .invocations .get (1 ), allOf (not (hasMvExpand ("emp_no" )), not (hasSample ()), not ( hasSum ()) ));
334+ assertThat (runner .invocations .get (2 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ()), not ( hasSum ()) ));
335+ assertThat (runner .invocations .get (3 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ()), hasSum () ));
334336 }
335337
336338 public void testCountPlan_largeDataAfterMvExpanding () throws Exception {
@@ -343,10 +345,10 @@ public void testCountPlan_largeDataAfterMvExpanding() throws Exception {
343345 // - one pass to get the number of expanded rows (which determines the sample probability)
344346 // - one pass to approximate the query
345347 assertThat (runner .invocations , hasSize (4 ));
346- assertThat (runner .invocations .get (0 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ())));
347- assertThat (runner .invocations .get (1 ), allOf (not (hasMvExpand ("emp_no" )), not (hasSample ())));
348- assertThat (runner .invocations .get (2 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ())));
349- assertThat (runner .invocations .get (3 ), allOf (hasMvExpand ("emp_no" ), hasSample (1e-4 )));
348+ assertThat (runner .invocations .get (0 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ()), hasSum () ));
349+ assertThat (runner .invocations .get (1 ), allOf (not (hasMvExpand ("emp_no" )), not (hasSample ()), not ( hasSum ()), not ( hasSum ()) ));
350+ assertThat (runner .invocations .get (2 ), allOf (hasMvExpand ("emp_no" ), not (hasSample ()), not ( hasSum ()), not ( hasSum ()) ));
351+ assertThat (runner .invocations .get (3 ), allOf (hasMvExpand ("emp_no" ), hasSample (1e-4 ), hasSum () ));
350352 }
351353
352354 public void testCountPlan_largeDataBeforeMvExpanding () throws Exception {
@@ -359,10 +361,10 @@ public void testCountPlan_largeDataBeforeMvExpanding() throws Exception {
359361 // - one pass to sample the number of expanded rows (which determines the sample probability)
360362 // - one pass to approximate the query
361363 assertThat (runner .invocations , hasSize (4 ));
362- assertThat (runner .invocations .get (0 ), allOf (hasMvExpand ("emp_no" )));
363- assertThat (runner .invocations .get (1 ), allOf (not (hasMvExpand ("emp_no" )), not (hasSample ())));
364- assertThat (runner .invocations .get (2 ), allOf (hasMvExpand ("emp_no" ), hasSample (1e-5 )));
365- assertThat (runner .invocations .get (3 ), allOf (hasMvExpand ("emp_no" ), hasSample (1e-7 )));
364+ assertThat (runner .invocations .get (0 ), allOf (hasMvExpand ("emp_no" ), hasSum () ));
365+ assertThat (runner .invocations .get (1 ), allOf (not (hasMvExpand ("emp_no" )), not (hasSample ()), not ( hasSum ()) ));
366+ assertThat (runner .invocations .get (2 ), allOf (hasMvExpand ("emp_no" ), hasSample (1e-5 ), not ( hasSum ()) ));
367+ assertThat (runner .invocations .get (3 ), allOf (hasMvExpand ("emp_no" ), hasSample (1e-7 ), hasSum () ));
366368 }
367369
368370 public void testCountPlan_sampleProbabilityThreshold_noFilter () throws Exception {
@@ -374,9 +376,9 @@ public void testCountPlan_sampleProbabilityThreshold_noFilter() throws Exception
374376 // - one pass to get the total number of rows
375377 // - one pass to execute the original query (because the sample probability is 20%)
376378 assertThat (runner .invocations , hasSize (3 ));
377- assertThat (runner .invocations .get (0 ), allOf (not (hasSample ())));
378- assertThat (runner .invocations .get (1 ), allOf (not (hasSample ())));
379- assertThat (runner .invocations .get (1 ), allOf (not (hasSample ())));
379+ assertThat (runner .invocations .get (0 ), allOf (not (hasSample ()), hasSum () ));
380+ assertThat (runner .invocations .get (1 ), allOf (not (hasSample ()), not ( hasSum ()) ));
381+ assertThat (runner .invocations .get (2 ), allOf (not (hasSample ()), hasSum ( )));
380382 }
381383
382384 public void testCountPlan_sampleProbabilityThreshold_withFilter () throws Exception {
@@ -389,12 +391,12 @@ public void testCountPlan_sampleProbabilityThreshold_withFilter() throws Excepti
389391 // - two passes to get the number of filtered rows (which determines the sample probability)
390392 // - one pass to execute the original query (because the sample probability is 50%)
391393 assertThat (runner .invocations , hasSize (6 ));
392- assertThat (runner .invocations .get (0 ), allOf (not (hasSample ()), hasFilter ("emp_no" )));
393- assertThat (runner .invocations .get (1 ), allOf (not (hasSample ()), not (hasFilter ("emp_no" ))));
394- assertThat (runner .invocations .get (2 ), allOf (hasSample (1e-8 ), hasFilter ("emp_no" )));
395- assertThat (runner .invocations .get (3 ), allOf (hasSample (1e-4 ), hasFilter ("emp_no" )));
396- assertThat (runner .invocations .get (4 ), allOf (hasSample (0.05 ), hasFilter ("emp_no" )));
397- assertThat (runner .invocations .get (5 ), allOf (not (hasSample ()), hasFilter ("emp_no" )));
394+ assertThat (runner .invocations .get (0 ), allOf (not (hasSample ()), hasFilter ("emp_no" ), hasSum () ));
395+ assertThat (runner .invocations .get (1 ), allOf (not (hasSample ()), not (hasFilter ("emp_no" )), not ( hasSum ()) ));
396+ assertThat (runner .invocations .get (2 ), allOf (hasSample (1e-8 ), hasFilter ("emp_no" ), not ( hasSum ()) ));
397+ assertThat (runner .invocations .get (3 ), allOf (hasSample (1e-4 ), hasFilter ("emp_no" ), not ( hasSum ()) ));
398+ assertThat (runner .invocations .get (4 ), allOf (hasSample (0.05 ), hasFilter ("emp_no" ), not ( hasSum ()) ));
399+ assertThat (runner .invocations .get (5 ), allOf (not (hasSample ()), hasFilter ("emp_no" ), hasSum () ));
398400 }
399401
400402 public void testApproximatePlan_createsConfidenceInterval_withoutGrouping () throws Exception {
@@ -486,6 +488,10 @@ private Matcher<? super LogicalPlan> hasSample(Double probability) {
486488 return hasPlan (Sample .class , sample -> sample .probability ().equals (Literal .fromDouble (Source .EMPTY , probability )));
487489 }
488490
491+ private Matcher <? super LogicalPlan > hasSum () {
492+ return hasPlan (Aggregate .class , aggr -> aggr .aggregates ().stream ().anyMatch (named -> named .anyMatch (expr -> expr instanceof Sum )));
493+ }
494+
489495 private <E extends LogicalPlan > Matcher <? super LogicalPlan > hasPlan (Class <E > typeToken , Predicate <? super E > predicate ) {
490496 return new TypeSafeMatcher <>() {
491497 @ Override
0 commit comments