Skip to content

Commit 0dbbb66

Browse files
committed
Improve confidence interval computation with NaNs + unit tests
1 parent ffea133 commit 0dbbb66

File tree

2 files changed

+224
-15
lines changed

2 files changed

+224
-15
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ConfidenceInterval.java

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
3333

3434
import java.io.IOException;
35-
import java.util.Arrays;
3635
import java.util.List;
3736
import java.util.Objects;
3837

@@ -200,6 +199,10 @@ static void process(
200199
}
201200
int trialCount = trialCountBlock.getInt(trialCountBlock.getFirstValueIndex(position));
202201
int bucketCount = bucketCountBlock.getInt(bucketCountBlock.getFirstValueIndex(position));
202+
if (estimates.length != trialCount * bucketCount) {
203+
builder.appendNull();
204+
return;
205+
}
203206
double confidenceLevel = confidenceLevelBlock.getDouble(confidenceLevelBlock.getFirstValueIndex(position));
204207
double[] confidenceInterval = computeConfidenceInterval(bestEstimate, estimates, trialCount, bucketCount, confidenceLevel);
205208
if (confidenceInterval == null) {
@@ -213,50 +216,79 @@ static void process(
213216
}
214217
}
215218

216-
public static double[] computeConfidenceInterval(
219+
static double[] computeConfidenceInterval(
217220
double bestEstimate,
218221
double[] estimates,
219222
int trialCount,
220223
int bucketCount,
221224
double confidenceLevel
222225
) {
223-
Mean means = new Mean();
226+
Mean meansIgnoreNaN = new Mean();
227+
Mean meansZeroNaN = new Mean();
228+
for (int trial = 0; trial < trialCount; trial++) {
229+
Mean meanIgnoreNaN = new Mean();
230+
Mean meanZeroNaN = new Mean();
231+
for (int bucket = 0; bucket < bucketCount; bucket++) {
232+
double estimate = estimates[trial * bucketCount + bucket];
233+
if (Double.isNaN(estimate) == false) {
234+
meanIgnoreNaN.increment(estimate);
235+
meanZeroNaN.increment(estimate);
236+
} else {
237+
meanZeroNaN.increment(0.0);
238+
}
239+
}
240+
if (meanIgnoreNaN.getN() >= 3) {
241+
meansIgnoreNaN.increment(meanIgnoreNaN.getResult());
242+
}
243+
if (meanZeroNaN.getN() >= 3) {
244+
meansZeroNaN.increment(meanZeroNaN.getResult());
245+
}
246+
}
247+
248+
double meanIgnoreNan = meansIgnoreNaN.getResult();
249+
double meanZeroNan = meansZeroNaN.getResult();
250+
251+
boolean ignoreNaNs = Math.abs(meanIgnoreNan - bestEstimate) < Math.abs(meanZeroNan - bestEstimate);
252+
double mm = ignoreNaNs ? meanIgnoreNan : meanZeroNan;
253+
224254
Mean stddevs = new Mean();
225255
Mean skews = new Mean();
226256
for (int trial = 0; trial < trialCount; trial++) {
227-
Mean mean = new Mean();
228257
StandardDeviation stdDev = new StandardDeviation(false);
229258
Skewness skew = new Skewness();
230259
for (int bucket = 0; bucket < bucketCount; bucket++) {
231260
double estimate = estimates[trial * bucketCount + bucket];
232261
if (Double.isNaN(estimate)) {
233-
continue;
262+
if (ignoreNaNs) {
263+
continue;
264+
} else {
265+
estimate = 0.0;
266+
}
234267
}
235-
mean.increment(estimate);
236268
stdDev.increment(estimate);
237269
skew.increment(estimate);
238270
}
239271
if (skew.getN() >= 3) {
240-
means.increment(mean.getResult());
241272
stddevs.increment(stdDev.getResult());
242273
skews.increment(skew.getResult());
243274
}
244275
}
245-
if (means.getN() == 0) {
246-
return null;
247-
}
276+
248277
double sm = stddevs.getResult();
249278
if (sm == 0.0) {
250279
return new double[] { bestEstimate, bestEstimate };
251280
}
252-
double mm = means.getResult();
253-
double a = skews.getResult() / (6.0 * Math.sqrt(bucketCount));
281+
282+
// Scale the acceleration to account for the dependence of skewness on sample size.
283+
double scale = 1 / Math.sqrt(bucketCount);
284+
double a = scale * skews.getResult() / 6.0;
254285
double z0 = (bestEstimate - mm) / sm;
255286
double dz = normal.inverseCumulativeProbability((1.0 + confidenceLevel) / 2.0);
256287
double zl = z0 + (z0 - dz) / (1.0 - Math.min(a * (z0 - dz), 0.9));
257288
double zu = z0 + (z0 + dz) / (1.0 - Math.min(a * (z0 + dz), 0.9));
258-
double scale = Math.max(1.0 / Math.sqrt(bucketCount), z0 < 0.0 ? z0 / zl : z0 / zu);
259-
return new double[] { mm + scale * sm * zl, mm + sm * scale * zu };
289+
double lower = mm + scale * sm * zl;
290+
double upper = mm + scale * sm * zu;
291+
return lower <= bestEstimate && bestEstimate <= upper ? new double[] { lower, upper } : null;
260292
}
261293

262294
@Override
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,181 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.expression.function.scalar.approximate;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static java.lang.Double.NaN;
import static org.hamcrest.Matchers.both;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.nullValue;

/**
 * Tests for {@link ConfidenceInterval}: interval computation on random data,
 * fully-populated buckets, the two NaN-handling policies, and data that is
 * inconsistent with the best estimate (expected to evaluate to null).
 */
public class ConfidenceIntervalTests extends AbstractScalarFunctionTestCase {

    /** Expected evaluator toString, shared by every test case. */
    private static final String EVALUATOR =
        "ConfidenceIntervalEvaluator[bestEstimateBlock=Attribute[channel=0], estimatesBlock=Attribute[channel=1], trialCountBlock="
            + "Attribute[channel=2], bucketCountBlock=Attribute[channel=3], confidenceLevelBlock=Attribute[channel=4]]";

    /** Argument types: best estimate, estimates, trial count, bucket count, confidence level. */
    private static final List<DataType> ARG_TYPES = List.of(
        DataType.DOUBLE,
        DataType.DOUBLE,
        DataType.INTEGER,
        DataType.INTEGER,
        DataType.DOUBLE
    );

    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        List<TestCaseSupplier> suppliers = new ArrayList<>();
        suppliers.add(randomBuckets());
        suppliers.add(allBucketsFilled());
        suppliers.add(nanBuckets_ignoreNan());
        suppliers.add(nanBuckets_zeroNan());
        // Was defined but never registered, so the null-result path went untested.
        suppliers.add(inconsistentData());
        return parameterSuppliersFromTypedDataWithDefaultChecks(false, suppliers);
    }

    public ConfidenceIntervalTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
        this.testCase = testCaseSupplier.get();
    }

    @Override
    protected Expression build(Source source, List<Expression> args) {
        return new ConfidenceInterval(source, args.get(0), args.get(1), args.get(2), args.get(3), args.get(4));
    }

    /** Random trials whose bucket {@code i} holds a value in {@code [i, i + 1)}, so the interval should straddle the midpoint. */
    private static TestCaseSupplier randomBuckets() {
        return new TestCaseSupplier("randomBuckets", ARG_TYPES, () -> {
            int trialCount = randomIntBetween(1, 10);
            int bucketCount = randomIntBetween(3, 10);
            double confidenceLevel = randomDoubleBetween(0.8, 0.95, true);
            double bestEstimate = bucketCount / 2.0;
            List<Double> estimates = IntStream.range(0, trialCount * bucketCount)
                .mapToDouble(i -> randomDoubleBetween((i % bucketCount), (i % bucketCount) + 1, true))
                .boxed()
                .toList();
            return new TestCaseSupplier.TestCase(
                List.of(
                    new TestCaseSupplier.TypedData(bestEstimate, DataType.DOUBLE, "bestEstimate"),
                    new TestCaseSupplier.TypedData(estimates, DataType.DOUBLE, "estimates"),
                    new TestCaseSupplier.TypedData(trialCount, DataType.INTEGER, "trialCount"),
                    new TestCaseSupplier.TypedData(bucketCount, DataType.INTEGER, "bucketCount"),
                    new TestCaseSupplier.TypedData(confidenceLevel, DataType.DOUBLE, "confidenceLevel")
                ),
                EVALUATOR,
                DataType.DOUBLE,
                contains(
                    both(greaterThan(0.0)).and(lessThan(bestEstimate)),
                    both(greaterThan(bestEstimate)).and(lessThan((double) bucketCount))
                )
            );
        });
    }

    /** Fixed data without NaNs: pins the exact interval endpoints. */
    private static TestCaseSupplier allBucketsFilled() {
        return new TestCaseSupplier(
            "allBucketsFilled",
            ARG_TYPES,
            () -> new TestCaseSupplier.TestCase(
                List.of(
                    new TestCaseSupplier.TypedData(2.0, DataType.DOUBLE, "bestEstimate"),
                    new TestCaseSupplier.TypedData(
                        List.of(2.15, 1.73, 2.1, 2.49, 2.41, 2.06, 2.29, 1.97, 1.54, 1.97, 2.41, 1.75, 1.55, 2.33, 1.64),
                        DataType.DOUBLE,
                        "estimates"
                    ),
                    new TestCaseSupplier.TypedData(3, DataType.INTEGER, "trialCount"),
                    new TestCaseSupplier.TypedData(5, DataType.INTEGER, "bucketCount"),
                    new TestCaseSupplier.TypedData(0.8, DataType.DOUBLE, "confidence_level")
                ),
                EVALUATOR,
                DataType.DOUBLE,
                contains(closeTo(1.8293144967855208, 1e-9), closeTo(2.164428203663303, 1e-9))
            )
        );
    }

    /** NaN-laden data where bestEstimate is closer to the NaN-ignoring mean, selecting the ignore policy. */
    private static TestCaseSupplier nanBuckets_ignoreNan() {
        return new TestCaseSupplier(
            "nanBuckets_ignoreNan",
            ARG_TYPES,
            () -> new TestCaseSupplier.TestCase(
                List.of(
                    new TestCaseSupplier.TypedData(2.0, DataType.DOUBLE, "bestEstimate"),
                    new TestCaseSupplier.TypedData(
                        List.of(2.15, NaN, NaN, 2.49, 2.41, NaN, 2.29, NaN, 1.54, 1.97, 2.41, NaN, 1.55, NaN, 1.64),
                        DataType.DOUBLE,
                        "estimates"
                    ),
                    new TestCaseSupplier.TypedData(3, DataType.INTEGER, "trialCount"),
                    new TestCaseSupplier.TypedData(5, DataType.INTEGER, "bucketCount"),
                    new TestCaseSupplier.TypedData(0.8, DataType.DOUBLE, "confidence_level")
                ),
                EVALUATOR,
                DataType.DOUBLE,
                contains(closeTo(1.8443260740876288, 1e-9), closeTo(2.164997868635109, 1e-9))
            )
        );
    }

    /** Same NaN-laden data with a lower bestEstimate, selecting the NaN-as-zero policy. */
    private static TestCaseSupplier nanBuckets_zeroNan() {
        return new TestCaseSupplier(
            "nanBuckets_zeroNan",
            ARG_TYPES,
            () -> new TestCaseSupplier.TestCase(
                List.of(
                    new TestCaseSupplier.TypedData(1.0, DataType.DOUBLE, "bestEstimate"),
                    new TestCaseSupplier.TypedData(
                        List.of(2.15, NaN, NaN, 2.49, 2.41, NaN, 2.29, NaN, 1.54, 1.97, 2.41, NaN, 1.55, NaN, 1.64),
                        DataType.DOUBLE,
                        "estimates"
                    ),
                    new TestCaseSupplier.TypedData(3, DataType.INTEGER, "trialCount"),
                    new TestCaseSupplier.TypedData(5, DataType.INTEGER, "bucketCount"),
                    new TestCaseSupplier.TypedData(0.8, DataType.DOUBLE, "confidence_level")
                ),
                EVALUATOR,
                DataType.DOUBLE,
                contains(closeTo(0.4041519539094244, 1e-9), closeTo(1.6023321533418913, 1e-9))
            )
        );
    }

    /** A best estimate far outside the bootstrap estimates: the interval cannot bracket it, so the result is null. */
    private static TestCaseSupplier inconsistentData() {
        return new TestCaseSupplier(
            "inconsistentData", // was mis-named "nanBuckets_zeroNan" by copy-paste
            ARG_TYPES,
            () -> new TestCaseSupplier.TestCase(
                List.of(
                    new TestCaseSupplier.TypedData(123.456, DataType.DOUBLE, "bestEstimate"),
                    new TestCaseSupplier.TypedData(
                        List.of(2.15, NaN, NaN, 2.49, 2.41, NaN, 2.29, NaN, 1.54, 1.97, 2.41, NaN, 1.55, NaN, 1.64),
                        DataType.DOUBLE,
                        "estimates"
                    ),
                    new TestCaseSupplier.TypedData(3, DataType.INTEGER, "trialCount"),
                    new TestCaseSupplier.TypedData(5, DataType.INTEGER, "bucketCount"),
                    new TestCaseSupplier.TypedData(0.8, DataType.DOUBLE, "confidence_level")
                ),
                EVALUATOR,
                DataType.DOUBLE,
                nullValue()
            )
        );
    }
}

0 commit comments

Comments
 (0)