Skip to content

Commit 6963ba7

Browse files
committed
merge confidence interval and reliable
1 parent 0dbbb66 commit 6963ba7

File tree

6 files changed

+95
-363
lines changed

6 files changed

+95
-363
lines changed

x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ReliableEvaluator.java

Lines changed: 0 additions & 144 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
3737
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.ConfidenceInterval;
3838
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.Random;
39-
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.Reliable;
4039
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
4140
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
4241
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
@@ -744,13 +743,19 @@ private LogicalPlan approximatePlan(double sampleProbability) {
744743
default -> throw new IllegalStateException("unexpected data type [" + output.dataType() + "]");
745744
};
746745
confidenceIntervalsAndReliable.add(
747-
new Alias(Source.EMPTY, "CONFIDENCE_INTERVAL(" + output.name() + ")", confidenceInterval)
746+
new Alias(Source.EMPTY, "CONFIDENCE_INTERVAL(" + output.name() + ")",
747+
new MvSlice(Source.EMPTY, confidenceInterval, Literal.integer(Source.EMPTY, 0), Literal.integer(Source.EMPTY, 1))
748+
)
748749
);
749750
confidenceIntervalsAndReliable.add(
750751
new Alias(
751752
Source.EMPTY,
752753
"RELIABLE(" + output.name() + ")",
753-
new Reliable(Source.EMPTY, bucketsMv, trialCount, bucketCount)
754+
new GreaterThanOrEqual(
755+
Source.EMPTY,
756+
new MvSlice(Source.EMPTY, confidenceInterval, Literal.integer(Source.EMPTY, 2), Literal.integer(Source.EMPTY, 2)),
757+
Literal.fromDouble(Source.EMPTY, 0.5)
758+
)
754759
)
755760
);
756761
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ScalarFunctionWritables.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingWritables;
1212
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.ConfidenceInterval;
13-
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.Reliable;
1413
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
1514
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.ClampMax;
1615
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.ClampMin;
@@ -108,7 +107,6 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
108107
entries.add(Or.ENTRY);
109108
entries.add(Pi.ENTRY);
110109
entries.add(Pow.ENTRY);
111-
entries.add(Reliable.ENTRY);
112110
entries.add(Right.ENTRY);
113111
entries.add(Repeat.ENTRY);
114112
entries.add(Replace.ENTRY);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/approximate/ConfidenceInterval.java

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.expression.function.scalar.approximate;
99

1010
import org.apache.commons.math3.distribution.NormalDistribution;
11+
import org.apache.commons.math3.stat.descriptive.moment.Kurtosis;
1112
import org.apache.commons.math3.stat.descriptive.moment.Mean;
1213
import org.apache.commons.math3.stat.descriptive.moment.Skewness;
1314
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
@@ -32,6 +33,7 @@
3233
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
3334

3435
import java.io.IOException;
36+
import java.util.ArrayList;
3537
import java.util.List;
3638
import java.util.Objects;
3739

@@ -193,16 +195,16 @@ static void process(
193195
return;
194196
}
195197
double bestEstimate = bestEstimateBlock.getDouble(bestEstimateBlock.getFirstValueIndex(position));
196-
double[] estimates = new double[estimatesBlock.getValueCount(position)];
197-
for (int i = 0; i < estimatesBlock.getValueCount(position); i++) {
198-
estimates[i] = estimatesBlock.getDouble(estimatesBlock.getFirstValueIndex(position) + i);
199-
}
200198
int trialCount = trialCountBlock.getInt(trialCountBlock.getFirstValueIndex(position));
201199
int bucketCount = bucketCountBlock.getInt(bucketCountBlock.getFirstValueIndex(position));
202-
if (estimates.length != trialCount * bucketCount) {
200+
if (estimatesBlock.getValueCount(position) != trialCount * bucketCount) {
203201
builder.appendNull();
204202
return;
205203
}
204+
double[] estimates = new double[estimatesBlock.getValueCount(position)];
205+
for (int i = 0; i < estimatesBlock.getValueCount(position); i++) {
206+
estimates[i] = estimatesBlock.getDouble(estimatesBlock.getFirstValueIndex(position) + i);
207+
}
206208
double confidenceLevel = confidenceLevelBlock.getDouble(confidenceLevelBlock.getFirstValueIndex(position));
207209
double[] confidenceInterval = computeConfidenceInterval(bestEstimate, estimates, trialCount, bucketCount, confidenceLevel);
208210
if (confidenceInterval == null) {
@@ -237,13 +239,17 @@ static double[] computeConfidenceInterval(
237239
meanZeroNaN.increment(0.0);
238240
}
239241
}
240-
if (meanIgnoreNaN.getN() >= 3) {
241-
meansIgnoreNaN.increment(meanIgnoreNaN.getResult());
242+
double value;
243+
if (Double.isNaN(value = meanIgnoreNaN.getResult()) == false) {
244+
meansIgnoreNaN.increment(value);
242245
}
243-
if (meanZeroNaN.getN() >= 3) {
244-
meansZeroNaN.increment(meanZeroNaN.getResult());
246+
if (Double.isNaN(value = meanZeroNaN.getResult()) == false) {
247+
meansZeroNaN.increment(value);
245248
}
246249
}
250+
if (Double.isNaN(meansIgnoreNaN.getResult()) || Double.isNaN(meansZeroNaN.getResult())) {
251+
return null;
252+
}
247253

248254
double meanIgnoreNan = meansIgnoreNaN.getResult();
249255
double meanZeroNan = meansZeroNaN.getResult();
@@ -253,42 +259,73 @@ static double[] computeConfidenceInterval(
253259

254260
Mean stddevs = new Mean();
255261
Mean skews = new Mean();
262+
Mean kurtoses = new Mean();
263+
int reliableCount = 0;
256264
for (int trial = 0; trial < trialCount; trial++) {
257-
StandardDeviation stdDev = new StandardDeviation(false);
265+
StandardDeviation stddev = new StandardDeviation(false);
258266
Skewness skew = new Skewness();
267+
Kurtosis kurtosis = new Kurtosis();
268+
boolean hasNans = false;
259269
for (int bucket = 0; bucket < bucketCount; bucket++) {
260270
double estimate = estimates[trial * bucketCount + bucket];
261271
if (Double.isNaN(estimate)) {
272+
hasNans = true;
262273
if (ignoreNaNs) {
263274
continue;
264275
} else {
265276
estimate = 0.0;
266277
}
267278
}
268-
stdDev.increment(estimate);
279+
stddev.increment(estimate);
269280
skew.increment(estimate);
281+
kurtosis.increment(estimate);
270282
}
271-
if (skew.getN() >= 3) {
272-
stddevs.increment(stdDev.getResult());
273-
skews.increment(skew.getResult());
283+
double stddevResult = stddev.getResult();
284+
if (Double.isNaN(stddevResult) == false) {
285+
stddevs.increment(stddevResult);
286+
}
287+
double skewResult = skew.getResult();
288+
if (Double.isNaN(skewResult) == false) {
289+
skews.increment(skewResult);
290+
}
291+
double kurtosisResult = kurtosis.getResult();
292+
if (Double.isNaN(kurtosisResult) == false) {
293+
kurtoses.increment(kurtosisResult);
294+
}
295+
if (hasNans == false && computeReliable(skewResult, kurtosisResult, bucketCount)) {
296+
reliableCount++;
274297
}
275298
}
276299

277300
double sm = stddevs.getResult();
301+
double skew = skews.getResult();
302+
if (Double.isNaN(sm) || Double.isNaN(skew)) {
303+
return null;
304+
}
278305
if (sm == 0.0) {
279-
return new double[] { bestEstimate, bestEstimate };
306+
return new double[] { bestEstimate, bestEstimate, 1.0 };
280307
}
281308

282309
// Scale the acceleration to account for the dependence of skewness on sample size.
283310
double scale = 1 / Math.sqrt(bucketCount);
284-
double a = scale * skews.getResult() / 6.0;
311+
double a = scale * skew / 6.0;
285312
double z0 = (bestEstimate - mm) / sm;
286313
double dz = normal.inverseCumulativeProbability((1.0 + confidenceLevel) / 2.0);
287314
double zl = z0 + (z0 - dz) / (1.0 - Math.min(a * (z0 - dz), 0.9));
288315
double zu = z0 + (z0 + dz) / (1.0 - Math.min(a * (z0 + dz), 0.9));
289316
double lower = mm + scale * sm * zl;
290317
double upper = mm + scale * sm * zu;
291-
return lower <= bestEstimate && bestEstimate <= upper ? new double[] { lower, upper } : null;
318+
319+
return lower <= bestEstimate && bestEstimate <= upper ? new double[] { lower, upper, (double) reliableCount / trialCount } : null;
320+
}
321+
322+
static boolean computeReliable(double skew, double kurtosis, int B) {
323+
if (Double.isNaN(skew) || Double.isNaN(kurtosis) || B < 4) {
324+
return false;
325+
}
326+
double maxSkew = Math.sqrt(6.0 * B * (B - 1) / ((B - 2) * (B + 1) * (B + 3))) * 1.96;
327+
double maxKurtosis = Math.sqrt(24.0 * B * (B - 1) * (B - 1) / ((B - 3) * (B - 2) * (B + 3) * (B + 5))) * 1.96;
328+
return Math.abs(skew) < maxSkew && Math.abs(kurtosis) < maxKurtosis;
292329
}
293330

294331
@Override

0 commit comments

Comments
 (0)