Skip to content

Commit 6ea91fb

Browse files
committed
improve testing
1 parent e305d70 commit 6ea91fb

File tree

1 file changed

+52
-46
lines changed

1 file changed

+52
-46
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximate/ApproximateTests.java

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@
2828
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
2929
import org.elasticsearch.xpack.esql.core.tree.Source;
3030
import org.elasticsearch.xpack.esql.expression.Foldables;
31+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
3132
import org.elasticsearch.xpack.esql.inference.InferenceService;
3233
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
3334
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
3435
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer;
3536
import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext;
3637
import org.elasticsearch.xpack.esql.parser.EsqlParser;
3738
import org.elasticsearch.xpack.esql.parser.QueryParams;
39+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
3840
import org.elasticsearch.xpack.esql.plan.logical.Eval;
3941
import org.elasticsearch.xpack.esql.plan.logical.Filter;
4042
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -233,9 +235,9 @@ public void testCountPlan_noData() throws Exception {
233235
// - one pass to get the total number of rows (which is zero)
234236
// - one pass to execute the original query
235237
assertThat(runner.invocations, hasSize(3));
236-
assertThat(runner.invocations.get(0), allOf(not(hasSample())));
237-
assertThat(runner.invocations.get(1), allOf(not(hasSample())));
238-
assertThat(runner.invocations.get(2), allOf(not(hasSample())));
238+
assertThat(runner.invocations.get(0), allOf(not(hasSample()), hasSum()));
239+
assertThat(runner.invocations.get(1), allOf(not(hasSample()), not(hasSum())));
240+
assertThat(runner.invocations.get(2), allOf(not(hasSample()), hasSum()));
239241
}
240242

241243
public void testCountPlan_largeDataNoFilters() throws Exception {
@@ -247,9 +249,9 @@ public void testCountPlan_largeDataNoFilters() throws Exception {
247249
// - one pass to get the total number of rows (which determines the sample probability)
248250
// - one pass to approximate the query
249251
assertThat(runner.invocations, hasSize(3));
250-
assertThat(runner.invocations.get(0), allOf(not(hasSample())));
251-
assertThat(runner.invocations.get(1), allOf(not(hasSample())));
252-
assertThat(runner.invocations.get(2), allOf(hasSample(1e-4)));
252+
assertThat(runner.invocations.get(0), allOf(not(hasSample()), hasSum()));
253+
assertThat(runner.invocations.get(1), allOf(not(hasSample()), not(hasSum())));
254+
assertThat(runner.invocations.get(2), allOf(hasSample(1e-4), hasSum()));
253255
}
254256

255257
public void testCountPlan_smallDataNoFilters() throws Exception {
@@ -261,9 +263,9 @@ public void testCountPlan_smallDataNoFilters() throws Exception {
261263
// - one pass to get the total number of rows (which is small)
262264
// - one pass to execute the original query
263265
assertThat(runner.invocations, hasSize(3));
264-
assertThat(runner.invocations.get(0), allOf(not(hasSample())));
265-
assertThat(runner.invocations.get(1), allOf(not(hasSample())));
266-
assertThat(runner.invocations.get(2), allOf(not(hasSample())));
266+
assertThat(runner.invocations.get(0), allOf(not(hasSample()), hasSum()));
267+
assertThat(runner.invocations.get(1), allOf(not(hasSample()), not(hasSum())));
268+
assertThat(runner.invocations.get(2), allOf(not(hasSample()), hasSum()));
267269
}
268270

269271
public void testCountPlan_largeDataAfterFiltering() throws Exception {
@@ -276,11 +278,11 @@ public void testCountPlan_largeDataAfterFiltering() throws Exception {
276278
// - two passes to get the number of filtered rows (which determines the sample probability)
277279
// - one pass to approximate the query
278280
assertThat(runner.invocations, hasSize(5));
279-
assertThat(runner.invocations.get(0), allOf(hasFilter("emp_no"), not(hasSample())));
280-
assertThat(runner.invocations.get(1), allOf(not(hasFilter("emp_no")), not(hasSample())));
281-
assertThat(runner.invocations.get(2), allOf(hasFilter("emp_no"), hasSample(1e-8)));
282-
assertThat(runner.invocations.get(3), allOf(hasFilter("emp_no"), hasSample(1e-5)));
283-
assertThat(runner.invocations.get(4), allOf(hasFilter("emp_no"), hasSample(1e-4)));
281+
assertThat(runner.invocations.get(0), allOf(hasFilter("emp_no"), not(hasSample()), hasSum()));
282+
assertThat(runner.invocations.get(1), allOf(not(hasFilter("emp_no")), not(hasSample()), not(hasSum())));
283+
assertThat(runner.invocations.get(2), allOf(hasFilter("emp_no"), hasSample(1e-8), not(hasSum())));
284+
assertThat(runner.invocations.get(3), allOf(hasFilter("emp_no"), hasSample(1e-5), not(hasSum())));
285+
assertThat(runner.invocations.get(4), allOf(hasFilter("emp_no"), hasSample(1e-4), hasSum()));
284286
}
285287

286288
public void testCountPlan_smallDataAfterFiltering() throws Exception {
@@ -293,14 +295,14 @@ public void testCountPlan_smallDataAfterFiltering() throws Exception {
293295
// - three passes to get the number of filtered rows (which is small)
294296
// - one pass to execute the original query
295297
assertThat(runner.invocations, hasSize(8));
296-
assertThat(runner.invocations.get(0), allOf(hasFilter("emp_no"), not(hasSample())));
297-
assertThat(runner.invocations.get(1), allOf(not(hasFilter("emp_no")), not(hasSample())));
298-
assertThat(runner.invocations.get(2), allOf(hasFilter("emp_no"), hasSample(1e-14)));
299-
assertThat(runner.invocations.get(3), allOf(hasFilter("emp_no"), hasSample(1e-10)));
300-
assertThat(runner.invocations.get(4), allOf(hasFilter("emp_no"), hasSample(1e-6)));
301-
assertThat(runner.invocations.get(5), allOf(hasFilter("emp_no"), hasSample(1e-2)));
302-
assertThat(runner.invocations.get(6), allOf(hasFilter("emp_no"), not(hasSample())));
303-
assertThat(runner.invocations.get(7), allOf(hasFilter("emp_no"), not(hasSample())));
298+
assertThat(runner.invocations.get(0), allOf(hasFilter("emp_no"), not(hasSample()), hasSum()));
299+
assertThat(runner.invocations.get(1), allOf(not(hasFilter("emp_no")), not(hasSample()), not(hasSum())));
300+
assertThat(runner.invocations.get(2), allOf(hasFilter("emp_no"), hasSample(1e-14), not(hasSum())));
301+
assertThat(runner.invocations.get(3), allOf(hasFilter("emp_no"), hasSample(1e-10), not(hasSum())));
302+
assertThat(runner.invocations.get(4), allOf(hasFilter("emp_no"), hasSample(1e-6), not(hasSum())));
303+
assertThat(runner.invocations.get(5), allOf(hasFilter("emp_no"), hasSample(1e-2), not(hasSum())));
304+
assertThat(runner.invocations.get(6), allOf(hasFilter("emp_no"), not(hasSample()), not(hasSum())));
305+
assertThat(runner.invocations.get(7), allOf(hasFilter("emp_no"), not(hasSample()), hasSum()));
304306
}
305307

306308
public void testCountPlan_smallDataBeforeFiltering() throws Exception {
@@ -312,9 +314,9 @@ public void testCountPlan_smallDataBeforeFiltering() throws Exception {
312314
// - one pass to get the total number of rows (which is small)
313315
// - one pass to execute the original query
314316
assertThat(runner.invocations, hasSize(3));
315-
assertThat(runner.invocations.get(0), allOf(hasFilter("gender"), not(hasSample())));
316-
assertThat(runner.invocations.get(1), allOf(not(hasFilter("gender")), not(hasSample())));
317-
assertThat(runner.invocations.get(2), allOf(hasFilter("gender"), not(hasSample())));
317+
assertThat(runner.invocations.get(0), allOf(hasFilter("gender"), not(hasSample()), hasSum()));
318+
assertThat(runner.invocations.get(1), allOf(not(hasFilter("gender")), not(hasSample()), not(hasSum())));
319+
assertThat(runner.invocations.get(2), allOf(hasFilter("gender"), not(hasSample()), hasSum()));
318320
}
319321

320322
public void testCountPlan_smallDataAfterMvExpanding() throws Exception {
@@ -327,10 +329,10 @@ public void testCountPlan_smallDataAfterMvExpanding() throws Exception {
327329
// - one pass to get the number of expanded rows (which determines the sample probability)
328330
// - one pass to execute the original query
329331
assertThat(runner.invocations, hasSize(4));
330-
assertThat(runner.invocations.get(0), allOf(hasMvExpand("emp_no"), not(hasSample())));
331-
assertThat(runner.invocations.get(1), allOf(not(hasMvExpand("emp_no")), not(hasSample())));
332-
assertThat(runner.invocations.get(2), allOf(hasMvExpand("emp_no"), not(hasSample())));
333-
assertThat(runner.invocations.get(3), allOf(hasMvExpand("emp_no"), not(hasSample())));
332+
assertThat(runner.invocations.get(0), allOf(hasMvExpand("emp_no"), not(hasSample()), hasSum()));
333+
assertThat(runner.invocations.get(1), allOf(not(hasMvExpand("emp_no")), not(hasSample()), not(hasSum())));
334+
assertThat(runner.invocations.get(2), allOf(hasMvExpand("emp_no"), not(hasSample()), not(hasSum())));
335+
assertThat(runner.invocations.get(3), allOf(hasMvExpand("emp_no"), not(hasSample()), hasSum()));
334336
}
335337

336338
public void testCountPlan_largeDataAfterMvExpanding() throws Exception {
@@ -343,10 +345,10 @@ public void testCountPlan_largeDataAfterMvExpanding() throws Exception {
343345
// - one pass to get the number of expanded rows (which determines the sample probability)
344346
// - one pass to approximate the query
345347
assertThat(runner.invocations, hasSize(4));
346-
assertThat(runner.invocations.get(0), allOf(hasMvExpand("emp_no"), not(hasSample())));
347-
assertThat(runner.invocations.get(1), allOf(not(hasMvExpand("emp_no")), not(hasSample())));
348-
assertThat(runner.invocations.get(2), allOf(hasMvExpand("emp_no"), not(hasSample())));
349-
assertThat(runner.invocations.get(3), allOf(hasMvExpand("emp_no"), hasSample(1e-4)));
348+
assertThat(runner.invocations.get(0), allOf(hasMvExpand("emp_no"), not(hasSample()), hasSum()));
349+
assertThat(runner.invocations.get(1), allOf(not(hasMvExpand("emp_no")), not(hasSample()), not(hasSum()), not(hasSum())));
350+
assertThat(runner.invocations.get(2), allOf(hasMvExpand("emp_no"), not(hasSample()), not(hasSum()), not(hasSum())));
351+
assertThat(runner.invocations.get(3), allOf(hasMvExpand("emp_no"), hasSample(1e-4), hasSum()));
350352
}
351353

352354
public void testCountPlan_largeDataBeforeMvExpanding() throws Exception {
@@ -359,10 +361,10 @@ public void testCountPlan_largeDataBeforeMvExpanding() throws Exception {
359361
// - one pass to sample the number of expanded rows (which determines the sample probability)
360362
// - one pass to approximate the query
361363
assertThat(runner.invocations, hasSize(4));
362-
assertThat(runner.invocations.get(0), allOf(hasMvExpand("emp_no")));
363-
assertThat(runner.invocations.get(1), allOf(not(hasMvExpand("emp_no")), not(hasSample())));
364-
assertThat(runner.invocations.get(2), allOf(hasMvExpand("emp_no"), hasSample(1e-5)));
365-
assertThat(runner.invocations.get(3), allOf(hasMvExpand("emp_no"), hasSample(1e-7)));
364+
assertThat(runner.invocations.get(0), allOf(hasMvExpand("emp_no"), hasSum()));
365+
assertThat(runner.invocations.get(1), allOf(not(hasMvExpand("emp_no")), not(hasSample()), not(hasSum())));
366+
assertThat(runner.invocations.get(2), allOf(hasMvExpand("emp_no"), hasSample(1e-5), not(hasSum())));
367+
assertThat(runner.invocations.get(3), allOf(hasMvExpand("emp_no"), hasSample(1e-7), hasSum()));
366368
}
367369

368370
public void testCountPlan_sampleProbabilityThreshold_noFilter() throws Exception {
@@ -374,9 +376,9 @@ public void testCountPlan_sampleProbabilityThreshold_noFilter() throws Exception
374376
// - one pass to get the total number of rows
375377
// - one pass to execute the original query (because the sample probability is 20%)
376378
assertThat(runner.invocations, hasSize(3));
377-
assertThat(runner.invocations.get(0), allOf(not(hasSample())));
378-
assertThat(runner.invocations.get(1), allOf(not(hasSample())));
379-
assertThat(runner.invocations.get(1), allOf(not(hasSample())));
379+
assertThat(runner.invocations.get(0), allOf(not(hasSample()), hasSum()));
380+
assertThat(runner.invocations.get(1), allOf(not(hasSample()), not(hasSum())));
381+
assertThat(runner.invocations.get(2), allOf(not(hasSample()), hasSum()));
380382
}
381383

382384
public void testCountPlan_sampleProbabilityThreshold_withFilter() throws Exception {
@@ -389,12 +391,12 @@ public void testCountPlan_sampleProbabilityThreshold_withFilter() throws Excepti
389391
// - two passes to get the number of filtered rows (which determines the sample probability)
390392
// - one pass to execute the original query (because the sample probability is 50%)
391393
assertThat(runner.invocations, hasSize(6));
392-
assertThat(runner.invocations.get(0), allOf(not(hasSample()), hasFilter("emp_no")));
393-
assertThat(runner.invocations.get(1), allOf(not(hasSample()), not(hasFilter("emp_no"))));
394-
assertThat(runner.invocations.get(2), allOf(hasSample(1e-8), hasFilter("emp_no")));
395-
assertThat(runner.invocations.get(3), allOf(hasSample(1e-4), hasFilter("emp_no")));
396-
assertThat(runner.invocations.get(4), allOf(hasSample(0.05), hasFilter("emp_no")));
397-
assertThat(runner.invocations.get(5), allOf(not(hasSample()), hasFilter("emp_no")));
394+
assertThat(runner.invocations.get(0), allOf(not(hasSample()), hasFilter("emp_no"), hasSum()));
395+
assertThat(runner.invocations.get(1), allOf(not(hasSample()), not(hasFilter("emp_no")), not(hasSum())));
396+
assertThat(runner.invocations.get(2), allOf(hasSample(1e-8), hasFilter("emp_no"), not(hasSum())));
397+
assertThat(runner.invocations.get(3), allOf(hasSample(1e-4), hasFilter("emp_no"), not(hasSum())));
398+
assertThat(runner.invocations.get(4), allOf(hasSample(0.05), hasFilter("emp_no"), not(hasSum())));
399+
assertThat(runner.invocations.get(5), allOf(not(hasSample()), hasFilter("emp_no"), hasSum()));
398400
}
399401

400402
public void testApproximatePlan_createsConfidenceInterval_withoutGrouping() throws Exception {
@@ -486,6 +488,10 @@ private Matcher<? super LogicalPlan> hasSample(Double probability) {
486488
return hasPlan(Sample.class, sample -> sample.probability().equals(Literal.fromDouble(Source.EMPTY, probability)));
487489
}
488490

491+
private Matcher<? super LogicalPlan> hasSum() {
492+
return hasPlan(Aggregate.class, aggr -> aggr.aggregates().stream().anyMatch(named -> named.anyMatch(expr -> expr instanceof Sum)));
493+
}
494+
489495
private <E extends LogicalPlan> Matcher<? super LogicalPlan> hasPlan(Class<E> typeToken, Predicate<? super E> predicate) {
490496
return new TypeSafeMatcher<>() {
491497
@Override

0 commit comments

Comments
 (0)