9
9
10
10
import org .apache .lucene .util .SetOnce ;
11
11
import org .elasticsearch .action .ActionListener ;
12
+ import org .elasticsearch .common .breaker .CircuitBreaker ;
13
+ import org .elasticsearch .common .unit .ByteSizeValue ;
14
+ import org .elasticsearch .common .util .BigArrays ;
15
+ import org .elasticsearch .common .util .MockBigArrays ;
16
+ import org .elasticsearch .common .util .PageCacheRecycler ;
17
+ import org .elasticsearch .compute .data .LongBlock ;
18
+ import org .elasticsearch .compute .data .Page ;
19
+ import org .elasticsearch .compute .test .MockBlockFactory ;
12
20
import org .elasticsearch .test .ESTestCase ;
13
21
import org .elasticsearch .xpack .esql .VerificationException ;
14
22
import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
23
+ import org .elasticsearch .xpack .esql .core .expression .Literal ;
15
24
import org .elasticsearch .xpack .esql .optimizer .LogicalPlanPreOptimizer ;
16
25
import org .elasticsearch .xpack .esql .optimizer .LogicalPreOptimizerContext ;
17
26
import org .elasticsearch .xpack .esql .parser .EsqlParser ;
27
+ import org .elasticsearch .xpack .esql .plan .logical .Filter ;
18
28
import org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
29
+ import org .elasticsearch .xpack .esql .plan .logical .Sample ;
30
+ import org .elasticsearch .xpack .esql .session .Result ;
31
+ import org .hamcrest .Description ;
19
32
import org .hamcrest .Matcher ;
33
+ import org .hamcrest .TypeSafeMatcher ;
34
+
35
+ import java .util .ArrayList ;
36
+ import java .util .List ;
20
37
21
38
import static org .elasticsearch .xpack .esql .EsqlTestUtils .TEST_CFG ;
39
+ import static org .hamcrest .CoreMatchers .allOf ;
40
+ import static org .hamcrest .CoreMatchers .not ;
22
41
import static org .hamcrest .Matchers .equalTo ;
42
+ import static org .hamcrest .Matchers .hasSize ;
23
43
24
44
public class ApproximateTests extends ESTestCase {
25
45
26
46
private static final EsqlParser parser = new EsqlParser ();
27
47
private static final LogicalPlanPreOptimizer logicalPlanPreOptimizer = new LogicalPlanPreOptimizer (
28
48
new LogicalPreOptimizerContext (FoldContext .small ())
29
49
);
50
+ private static final CircuitBreaker breaker = newLimitedBreaker (ByteSizeValue .ofGb (1 ));
51
+ private static final BigArrays bigArrays = new MockBigArrays (PageCacheRecycler .NON_RECYCLING_INSTANCE , ByteSizeValue .ofGb (1 ));
52
+ private static final MockBlockFactory blockFactory = new MockBlockFactory (breaker , bigArrays );
30
53
31
54
public void testVerify_validQuery () throws Exception {
32
- verifyQuery ("FROM index | EVAL x=1 | WHERE y<0.1 | SORT z | SAMPLE 0.1 | STATS COUNT() BY y | SORT z | MV_EXPAND x" );
55
+ createApproximate ("FROM index | EVAL x=1 | WHERE y<0.1 | SORT z | SAMPLE 0.1 | STATS COUNT() BY y | SORT z | MV_EXPAND x" );
33
56
}
34
57
35
58
public void testVerify_noStats () {
@@ -80,12 +103,153 @@ public void testVerify_nonSwappableCommand() {
80
103
);
81
104
}
82
105
106
+ /**
107
+ * Runner that simulates the execution of an ESQL query.
108
+ *
109
+ * The runner always returns a result with one field: the number of rows.
110
+ *
111
+ * The runner is initialized with a total number of rows (returned when
112
+ * there are no filters in the query), and a number of filtered rows
113
+ * (returned when there are filters in the query). If there's random
114
+ * sampling in the query, the returned number of rows is multiplied by
115
+ * the sampling probability.
116
+ *
117
+ * The runner collects all its invocations.
118
+ */
119
+ private static class TestRunner implements Approximate .LogicalPlanRunner {
120
+
121
+ private final long totalRows ;
122
+ private final long filteredRows ;
123
+ private final List <LogicalPlan > invocations ;
124
+
125
+ TestRunner (long totalRows , long filteredRows ) {
126
+ this .totalRows = totalRows ;
127
+ this .filteredRows = filteredRows ;
128
+ this .invocations = new ArrayList <>();
129
+ }
130
+
131
+ @ Override
132
+ public void run (LogicalPlan logicalPlan , ActionListener <Result > listener ) {
133
+ invocations .add (logicalPlan );
134
+ List <LogicalPlan > filter = logicalPlan .collect (plan -> plan instanceof Filter );
135
+ long numResults = filter .isEmpty () ? totalRows : filteredRows ;
136
+ List <LogicalPlan > sample = logicalPlan .collect (plan -> plan instanceof Sample );
137
+ if (sample .isEmpty () == false ) {
138
+ numResults = (long ) (numResults * (double ) ((Literal ) ((Sample ) sample .getFirst ()).probability ()).value ());
139
+ }
140
+ LongBlock block = blockFactory .newConstantLongBlockWith (numResults , 1 );
141
+ listener .onResponse (new Result (null , List .of (new Page (block )), null , null ));
142
+ }
143
+ }
144
+
145
+ public void testApproximate_largeDataNoFilters () throws Exception {
146
+ Approximate approximate = createApproximate ("FROM index | STATS SUM(x), AVG(x)" );
147
+ TestRunner runner = new TestRunner (1_000_000_000 , 1_000_000_000 );
148
+ approximate .approximate (runner , ActionListener .noop ());
149
+ // One pass is needed to get the number of rows, and approximation is executed immediately
150
+ // after that with the correct sample probability.
151
+ assertThat (runner .invocations , hasSize (2 ));
152
+ assertThat (runner .invocations .get (0 ), allOf (not (hasFilter ()), not (hasSample ())));
153
+ assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ()), hasSample (1e-5 )));
154
+ System .out .println (runner .invocations );
155
+ }
156
+
157
+ public void testApproximate_smallDataNoFilters () throws Exception {
158
+ Approximate approximate = createApproximate ("FROM index | STATS SUM(x), AVG(x)" );
159
+ TestRunner runner = new TestRunner (1_000 , 1_000 );
160
+ approximate .approximate (runner , ActionListener .noop ());
161
+ // One pass is needed to get the number of rows, and the original query is executed
162
+ // immediately after that without sampling.
163
+ assertThat (runner .invocations , hasSize (2 ));
164
+ assertThat (runner .invocations .get (0 ), allOf (not (hasFilter ()), not (hasSample ())));
165
+ assertThat (runner .invocations .get (1 ), allOf (not (hasFilter ()), not (hasSample ())));
166
+ }
167
+
168
+ public void testApproximate_largeDataAfterFiltering () throws Exception {
169
+ Approximate approximate = createApproximate ("FROM index | WHERE t < 1 | STATS SUM(x), AVG(x)" );
170
+ TestRunner runner = new TestRunner (1_000_000_000_000L , 1_000_000_000 );
171
+ approximate .approximate (runner , ActionListener .noop ());
172
+ // One pass is needed to get the number of rows, then a few passes to get a good sample
173
+ // probability, and finally approximation is executed.
174
+ assertThat (runner .invocations , hasSize (4 ));
175
+ assertThat (runner .invocations .get (0 ), allOf (not (hasFilter ()), not (hasSample ())));
176
+ assertThat (runner .invocations .get (1 ), allOf (hasFilter (), hasSample (1e-8 )));
177
+ assertThat (runner .invocations .get (2 ), allOf (hasFilter (), hasSample (1e-5 )));
178
+ assertThat (runner .invocations .get (3 ), allOf (hasFilter (), hasSample (1e-5 )));
179
+ }
180
+
181
+ public void testApproximate_smallDataAfterFiltering () throws Exception {
182
+ Approximate approximate = createApproximate ("FROM index | WHERE t < 1 | STATS SUM(x), AVG(x)" );
183
+ TestRunner runner = new TestRunner (1_000_000_000_000_000_000L , 100 );
184
+ approximate .approximate (runner , ActionListener .noop ());
185
+ // One pass is needed to get the number of rows, then a few passes to get a good sample
186
+ // probability, and finally the original query is executed without sampling.
187
+ assertThat (runner .invocations , hasSize (6 ));
188
+ assertThat (runner .invocations .get (0 ), allOf (not (hasFilter ()), not (hasSample ())));
189
+ assertThat (runner .invocations .get (1 ), allOf (hasFilter (), hasSample (1e-14 )));
190
+ assertThat (runner .invocations .get (2 ), allOf (hasFilter (), hasSample (1e-10 )));
191
+ assertThat (runner .invocations .get (3 ), allOf (hasFilter (), hasSample (1e-6 )));
192
+ assertThat (runner .invocations .get (4 ), allOf (hasFilter (), hasSample (1e-2 )));
193
+ assertThat (runner .invocations .get (5 ), allOf (hasFilter (), not (hasSample ())));
194
+ }
195
+
196
+ public void testApproximate_smallDataBeforeFiltering () throws Exception {
197
+ Approximate approximate = createApproximate ("FROM index | WHERE t < 1 | STATS SUM(x), AVG(x)" );
198
+ TestRunner runner = new TestRunner (1_000 , 10 );
199
+ approximate .approximate (runner , ActionListener .noop ());
200
+ // One pass is needed to get the number of rows, and the original query is executed
201
+ // immediately after that without sampling.
202
+ assertThat (runner .invocations , hasSize (2 ));
203
+ assertThat (runner .invocations .get (0 ), allOf (not (hasFilter ()), not (hasSample ())));
204
+ assertThat (runner .invocations .get (1 ), allOf (hasFilter (), not (hasSample ())));
205
+ }
206
+
207
+ private Matcher <? super LogicalPlan > hasFilter () {
208
+ return new TypeSafeMatcher <>() {
209
+
210
+ @ Override
211
+ protected boolean matchesSafely (LogicalPlan logicalPlan ) {
212
+ return logicalPlan .anyMatch (plan -> plan instanceof Filter );
213
+ }
214
+
215
+ @ Override
216
+ public void describeTo (Description description ) {
217
+ description .appendText ("a plan containing a Filter" );
218
+ }
219
+ };
220
+ }
221
+
222
+ private Matcher <? super LogicalPlan > hasSample () {
223
+ return hasSample (null );
224
+ }
225
+
226
+ private Matcher <? super LogicalPlan > hasSample (Double probability ) {
227
+ return new TypeSafeMatcher <>() {
228
+
229
+ @ Override
230
+ protected boolean matchesSafely (LogicalPlan logicalPlan ) {
231
+ return logicalPlan .anyMatch (
232
+ plan -> plan instanceof Sample
233
+ && (probability == null || ((Literal ) ((Sample ) plan ).probability ()).value ().equals (probability ))
234
+ );
235
+ }
236
+
237
+ @ Override
238
+ public void describeTo (Description description ) {
239
+ description .appendText ("a plan containing a Sample" );
240
+ if (probability != null ) {
241
+ description .appendText (" with probability " + probability );
242
+ }
243
+ }
244
+ };
245
+ }
246
+
83
247
private void assertError (String esql , Matcher <String > matcher ) {
84
- Exception e = assertThrows (VerificationException .class , () -> verifyQuery (esql ));
248
+ Exception e = assertThrows (VerificationException .class , () -> createApproximate (esql ));
85
249
assertThat (e .getMessage ().substring ("Found 1 problem\n " .length ()), matcher );
86
250
}
87
251
88
- private void verifyQuery (String esql ) throws Exception {
252
+ private Approximate createApproximate (String esql ) throws Exception {
89
253
SetOnce <LogicalPlan > resultHolder = new SetOnce <>();
90
254
SetOnce <Exception > exceptionHolder = new SetOnce <>();
91
255
LogicalPlan plan = parser .createStatement (esql , TEST_CFG );
@@ -94,6 +258,6 @@ private void verifyQuery(String esql) throws Exception {
94
258
if (exceptionHolder .get () != null ) {
95
259
throw exceptionHolder .get ();
96
260
}
97
- new Approximate (resultHolder .get ());
261
+ return new Approximate (resultHolder .get ());
98
262
}
99
263
}
0 commit comments