Skip to content

Commit 3ff71a5

Browse files
committed
Test accuracy of sampling operator
1 parent bd82576 commit 3ff71a5

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SampleOperatorTests.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
import java.util.List;
1717
import java.util.stream.LongStream;
1818

19+
import static org.hamcrest.Matchers.both;
1920
import static org.hamcrest.Matchers.closeTo;
2021
import static org.hamcrest.Matchers.equalTo;
22+
import static org.hamcrest.Matchers.greaterThan;
23+
import static org.hamcrest.Matchers.lessThan;
2124
import static org.hamcrest.Matchers.matchesPattern;
2225

2326
public class SampleOperatorTests extends OperatorTestCase {
@@ -37,7 +40,7 @@ protected void assertSimpleOutput(List<Page> input, List<Page> results) {
3740
}
3841

3942
/**
 * Builds the {@code SampleOperator.Factory} used by this test class, constructed with the
 * arguments {@code 0.5} and {@code randomInt()}.
 *
 * The diff narrows the declared return type from {@code Operator.OperatorFactory} to
 * {@code SampleOperator.Factory}; {@code testAccuracy} assigns {@code simple().get(...)}
 * directly to a {@code SampleOperator} variable.
 *
 * NOTE(review): 0.5 is presumably the sampling probability and the int a random seed --
 * confirm against the SampleOperator.Factory constructor.
 */
@Override
40-
protected Operator.OperatorFactory simple() {
43+
protected SampleOperator.Factory simple() {
4144
return new SampleOperator.Factory(0.5, randomInt());
4245
}
4346

@@ -50,4 +53,23 @@ protected Matcher<String> expectedDescriptionOfSimple() {
5053
// Returns a matcher that requires exact equality with "SampleOperator[sampled = 0/0]".
// NOTE(review): presumably the base test case compares this against toString() of a freshly
// created operator (hence the zero counts) -- confirm in OperatorTestCase.
protected Matcher<String> expectedToStringOfSimple() {
5154
return equalTo("SampleOperator[sampled = 0/0]");
5255
}
56+
57+
/**
 * Checks that the operator built by {@link #simple()} keeps roughly half of its input.
 *
 * Runs 10000 iterations; each one creates a new operator, feeds it a single page holding a
 * 20000-position constant-null block, and asserts that the output page has strictly between
 * 9000 and 11000 positions. Afterwards the mean output size over all iterations (computed with
 * integer division, so truncated) must lie strictly between 9990 and 10010.
 *
 * NOTE(review): the block factory comes from one driverContext() call while each operator is
 * created with another driverContext() call -- confirm that mixing contexts is intended.
 */
public void testAccuracy() {
58+
BlockFactory blockFactory = driverContext().blockFactory();
59+
int totalPositionCount = 0;
60+
61+
for (int iter = 0; iter < 10000; iter++) {
62+
SampleOperator operator = simple().get(driverContext());
63+
operator.addInput(new Page(blockFactory.newConstantNullBlock(20000)));
64+
Page output = operator.getOutput();
65+
// 10000 expected rows out of 20000; taking stddev=sqrt(10000)=100, the (9000, 11000) window is 10 stddevs on each side.
// NOTE(review): if rows are kept independently with probability 0.5, the stddev would be
// sqrt(20000*0.5*0.5) ~= 71, which makes this window even looser -- confirm the sampling scheme.
66+
assertThat(output.getPositionCount(), both(greaterThan(9000)).and(lessThan(11000)));
67+
totalPositionCount += output.getPositionCount();
68+
output.releaseBlocks();
69+
}
70+
71+
int averagePositionCount = totalPositionCount / 10000;
72+
// Averaging over 10000 runs divides the stddev by sqrt(10000)=100, so the (9990, 10010) window is 10 stddevs again.
73+
assertThat(averagePositionCount, both(greaterThan(9990)).and(lessThan(10010)));
74+
}
5375
}

0 commit comments

Comments
 (0)