Skip to content

Commit 893b0f8

Browse files
committed
test row sampling behavior
1 parent c308c5a commit 893b0f8

File tree

2 files changed

+171
-7
lines changed

2 files changed

+171
-7
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ private ActionListener<Result> sourceCountListener(LogicalPlanRunner runner, Act
199199
logger.debug("sourceCountPlan result: {} rows", rowCount(countResult));
200200
double sampleProbability = sampleProbability(countResult);
201201
countResult.pages().getFirst().close();
202-
if (hasFilters) {
203-
runner.run(countPlan(sampleProbability), countListener(runner, sampleProbability, listener));
204-
} else {
202+
if (hasFilters == false || sampleProbability == 1.0) {
205203
runner.run(approximatePlan(sampleProbability), listener);
204+
} else {
205+
runner.run(countPlan(sampleProbability), countListener(runner, sampleProbability, listener));
206206
}
207207
});
208208
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/approximate/ApproximateTests.java

Lines changed: 168 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,50 @@
99

1010
import org.apache.lucene.util.SetOnce;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.common.breaker.CircuitBreaker;
13+
import org.elasticsearch.common.unit.ByteSizeValue;
14+
import org.elasticsearch.common.util.BigArrays;
15+
import org.elasticsearch.common.util.MockBigArrays;
16+
import org.elasticsearch.common.util.PageCacheRecycler;
17+
import org.elasticsearch.compute.data.LongBlock;
18+
import org.elasticsearch.compute.data.Page;
19+
import org.elasticsearch.compute.test.MockBlockFactory;
1220
import org.elasticsearch.test.ESTestCase;
1321
import org.elasticsearch.xpack.esql.VerificationException;
1422
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
23+
import org.elasticsearch.xpack.esql.core.expression.Literal;
1524
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer;
1625
import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext;
1726
import org.elasticsearch.xpack.esql.parser.EsqlParser;
27+
import org.elasticsearch.xpack.esql.plan.logical.Filter;
1828
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
29+
import org.elasticsearch.xpack.esql.plan.logical.Sample;
30+
import org.elasticsearch.xpack.esql.session.Result;
31+
import org.hamcrest.Description;
1932
import org.hamcrest.Matcher;
33+
import org.hamcrest.TypeSafeMatcher;
34+
35+
import java.util.ArrayList;
36+
import java.util.List;
2037

2138
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
39+
import static org.hamcrest.CoreMatchers.allOf;
40+
import static org.hamcrest.CoreMatchers.not;
2241
import static org.hamcrest.Matchers.equalTo;
42+
import static org.hamcrest.Matchers.hasSize;
2343

2444
public class ApproximateTests extends ESTestCase {
2545

2646
private static final EsqlParser parser = new EsqlParser();
2747
private static final LogicalPlanPreOptimizer logicalPlanPreOptimizer = new LogicalPlanPreOptimizer(
2848
new LogicalPreOptimizerContext(FoldContext.small())
2949
);
50+
private static final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofGb(1));
51+
private static final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofGb(1));
52+
private static final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
3053

3154
public void testVerify_validQuery() throws Exception {
32-
verifyQuery("FROM index | EVAL x=1 | WHERE y<0.1 | SORT z | SAMPLE 0.1 | STATS COUNT() BY y | SORT z | MV_EXPAND x");
55+
createApproximate("FROM index | EVAL x=1 | WHERE y<0.1 | SORT z | SAMPLE 0.1 | STATS COUNT() BY y | SORT z | MV_EXPAND x");
3356
}
3457

3558
public void testVerify_noStats() {
@@ -80,12 +103,153 @@ public void testVerify_nonSwappableCommand() {
80103
);
81104
}
82105

106+
/**
107+
* Runner that simulates the execution of an ESQL query.
108+
*
109+
* The runner always returns a result with one field: the number of rows.
110+
*
111+
* The runner is initialized with a total number of rows (returned when
112+
* there are no filters in the query), and a number of filtered rows
113+
* (returned when there are filters in the query). If there's random
114+
* sampling in the query, the returned number of rows is multiplied by
115+
* the sampling probability.
116+
*
117+
* The runner collects all its invocations.
118+
*/
119+
private static class TestRunner implements Approximate.LogicalPlanRunner {
120+
121+
private final long totalRows;
122+
private final long filteredRows;
123+
private final List<LogicalPlan> invocations;
124+
125+
TestRunner(long totalRows, long filteredRows) {
126+
this.totalRows = totalRows;
127+
this.filteredRows = filteredRows;
128+
this.invocations = new ArrayList<>();
129+
}
130+
131+
@Override
132+
public void run(LogicalPlan logicalPlan, ActionListener<Result> listener) {
133+
invocations.add(logicalPlan);
134+
List<LogicalPlan> filter = logicalPlan.collect(plan -> plan instanceof Filter);
135+
long numResults = filter.isEmpty() ? totalRows : filteredRows;
136+
List<LogicalPlan> sample = logicalPlan.collect(plan -> plan instanceof Sample);
137+
if (sample.isEmpty() == false) {
138+
numResults = (long) (numResults * (double) ((Literal) ((Sample) sample.getFirst()).probability()).value());
139+
}
140+
LongBlock block = blockFactory.newConstantLongBlockWith(numResults, 1);
141+
listener.onResponse(new Result(null, List.of(new Page(block)), null, null));
142+
}
143+
}
144+
145+
public void testApproximate_largeDataNoFilters() throws Exception {
146+
Approximate approximate = createApproximate("FROM index | STATS SUM(x), AVG(x)");
147+
TestRunner runner = new TestRunner(1_000_000_000, 1_000_000_000);
148+
approximate.approximate(runner, ActionListener.noop());
149+
// One pass is needed to get the number of rows, and approximation is executed immediately
150+
// after that with the correct sample probability.
151+
assertThat(runner.invocations, hasSize(2));
152+
assertThat(runner.invocations.get(0), allOf(not(hasFilter()), not(hasSample())));
153+
assertThat(runner.invocations.get(1), allOf(not(hasFilter()), hasSample(1e-5)));
154+
System.out.println(runner.invocations);
155+
}
156+
157+
public void testApproximate_smallDataNoFilters() throws Exception {
158+
Approximate approximate = createApproximate("FROM index | STATS SUM(x), AVG(x)");
159+
TestRunner runner = new TestRunner(1_000, 1_000);
160+
approximate.approximate(runner, ActionListener.noop());
161+
// One pass is needed to get the number of rows, and the original query is executed
162+
// immediately after that without sampling.
163+
assertThat(runner.invocations, hasSize(2));
164+
assertThat(runner.invocations.get(0), allOf(not(hasFilter()), not(hasSample())));
165+
assertThat(runner.invocations.get(1), allOf(not(hasFilter()), not(hasSample())));
166+
}
167+
168+
public void testApproximate_largeDataAfterFiltering() throws Exception {
169+
Approximate approximate = createApproximate("FROM index | WHERE t < 1 | STATS SUM(x), AVG(x)");
170+
TestRunner runner = new TestRunner(1_000_000_000_000L, 1_000_000_000);
171+
approximate.approximate(runner, ActionListener.noop());
172+
// One pass is needed to get the number of rows, then a few passes to get a good sample
173+
// probability, and finally approximation is executed.
174+
assertThat(runner.invocations, hasSize(4));
175+
assertThat(runner.invocations.get(0), allOf(not(hasFilter()), not(hasSample())));
176+
assertThat(runner.invocations.get(1), allOf(hasFilter(), hasSample(1e-8)));
177+
assertThat(runner.invocations.get(2), allOf(hasFilter(), hasSample(1e-5)));
178+
assertThat(runner.invocations.get(3), allOf(hasFilter(), hasSample(1e-5)));
179+
}
180+
181+
public void testApproximate_smallDataAfterFiltering() throws Exception {
182+
Approximate approximate = createApproximate("FROM index | WHERE t < 1 | STATS SUM(x), AVG(x)");
183+
TestRunner runner = new TestRunner(1_000_000_000_000_000_000L, 100);
184+
approximate.approximate(runner, ActionListener.noop());
185+
// One pass is needed to get the number of rows, then a few passes to get a good sample
186+
// probability, and finally the original query is executed without sampling.
187+
assertThat(runner.invocations, hasSize(6));
188+
assertThat(runner.invocations.get(0), allOf(not(hasFilter()), not(hasSample())));
189+
assertThat(runner.invocations.get(1), allOf(hasFilter(), hasSample(1e-14)));
190+
assertThat(runner.invocations.get(2), allOf(hasFilter(), hasSample(1e-10)));
191+
assertThat(runner.invocations.get(3), allOf(hasFilter(), hasSample(1e-6)));
192+
assertThat(runner.invocations.get(4), allOf(hasFilter(), hasSample(1e-2)));
193+
assertThat(runner.invocations.get(5), allOf(hasFilter(), not(hasSample())));
194+
}
195+
196+
public void testApproximate_smallDataBeforeFiltering() throws Exception {
197+
Approximate approximate = createApproximate("FROM index | WHERE t < 1 | STATS SUM(x), AVG(x)");
198+
TestRunner runner = new TestRunner(1_000, 10);
199+
approximate.approximate(runner, ActionListener.noop());
200+
// One pass is needed to get the number of rows, and the original query is executed
201+
// immediately after that without sampling.
202+
assertThat(runner.invocations, hasSize(2));
203+
assertThat(runner.invocations.get(0), allOf(not(hasFilter()), not(hasSample())));
204+
assertThat(runner.invocations.get(1), allOf(hasFilter(), not(hasSample())));
205+
}
206+
207+
private Matcher<? super LogicalPlan> hasFilter() {
208+
return new TypeSafeMatcher<>() {
209+
210+
@Override
211+
protected boolean matchesSafely(LogicalPlan logicalPlan) {
212+
return logicalPlan.anyMatch(plan -> plan instanceof Filter);
213+
}
214+
215+
@Override
216+
public void describeTo(Description description) {
217+
description.appendText("a plan containing a Filter");
218+
}
219+
};
220+
}
221+
222+
private Matcher<? super LogicalPlan> hasSample() {
223+
return hasSample(null);
224+
}
225+
226+
private Matcher<? super LogicalPlan> hasSample(Double probability) {
227+
return new TypeSafeMatcher<>() {
228+
229+
@Override
230+
protected boolean matchesSafely(LogicalPlan logicalPlan) {
231+
return logicalPlan.anyMatch(
232+
plan -> plan instanceof Sample
233+
&& (probability == null || ((Literal) ((Sample) plan).probability()).value().equals(probability))
234+
);
235+
}
236+
237+
@Override
238+
public void describeTo(Description description) {
239+
description.appendText("a plan containing a Sample");
240+
if (probability != null) {
241+
description.appendText(" with probability " + probability);
242+
}
243+
}
244+
};
245+
}
246+
83247
private void assertError(String esql, Matcher<String> matcher) {
84-
Exception e = assertThrows(VerificationException.class, () -> verifyQuery(esql));
248+
Exception e = assertThrows(VerificationException.class, () -> createApproximate(esql));
85249
assertThat(e.getMessage().substring("Found 1 problem\n".length()), matcher);
86250
}
87251

88-
private void verifyQuery(String esql) throws Exception {
252+
private Approximate createApproximate(String esql) throws Exception {
89253
SetOnce<LogicalPlan> resultHolder = new SetOnce<>();
90254
SetOnce<Exception> exceptionHolder = new SetOnce<>();
91255
LogicalPlan plan = parser.createStatement(esql, TEST_CFG);
@@ -94,6 +258,6 @@ private void verifyQuery(String esql) throws Exception {
94258
if (exceptionHolder.get() != null) {
95259
throw exceptionHolder.get();
96260
}
97-
new Approximate(resultHolder.get());
261+
return new Approximate(resultHolder.get());
98262
}
99263
}

0 commit comments

Comments
 (0)