Skip to content

Commit 0954e9d

Browse files
committed
move+hide+improve confidence interval computation
1 parent f658283 commit 0954e9d

File tree

7 files changed

+19
-27
lines changed

7 files changed

+19
-27
lines changed
Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
3232
import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg;
3333
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
34+
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.ConfidenceInterval;
3435
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
35-
import org.elasticsearch.xpack.esql.expression.function.scalar.math.ConfidenceInterval;
3636
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
3737
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvContains;
3838
import org.elasticsearch.xpack.esql.expression.function.scalar.random.Random;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@
118118
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Atan2;
119119
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cbrt;
120120
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Ceil;
121-
import org.elasticsearch.xpack.esql.expression.function.scalar.math.ConfidenceInterval;
122121
import org.elasticsearch.xpack.esql.expression.function.scalar.math.CopySign;
123122
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cos;
124123
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cosh;
@@ -371,7 +370,6 @@ private static FunctionDefinition[][] functions() {
371370
def(Atan2.class, Atan2::new, "atan2"),
372371
def(Cbrt.class, Cbrt::new, "cbrt"),
373372
def(Ceil.class, Ceil::new, "ceil"),
374-
def(ConfidenceInterval.class, ConfidenceInterval::new, "confidence_interval"),
375373
def(Cos.class, Cos::new, "cos"),
376374
def(Cosh.class, Cosh::new, "cosh"),
377375
def(E.class, E::new, "e"),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ScalarFunctionWritables.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingWritables;
12+
import org.elasticsearch.xpack.esql.expression.function.scalar.approximate.ConfidenceInterval;
1213
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
1314
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.ClampMax;
1415
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.ClampMin;
@@ -26,7 +27,6 @@
2627
import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch;
2728
import org.elasticsearch.xpack.esql.expression.function.scalar.ip.IpPrefix;
2829
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Atan2;
29-
import org.elasticsearch.xpack.esql.expression.function.scalar.math.ConfidenceInterval;
3030
import org.elasticsearch.xpack.esql.expression.function.scalar.math.CopySign;
3131
import org.elasticsearch.xpack.esql.expression.function.scalar.math.E;
3232
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Hypot;
Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.esql.expression.function.scalar.math;
8+
package org.elasticsearch.xpack.esql.expression.function.scalar.approximate;
99

1010
import org.apache.commons.math3.distribution.NormalDistribution;
1111
import org.apache.commons.math3.stat.descriptive.moment.Mean;
@@ -49,6 +49,8 @@ public class ConfidenceInterval extends EsqlScalarFunction {
4949
ConfidenceInterval::new
5050
);
5151

52+
private static final NormalDistribution normal = new NormalDistribution();
53+
5254
private final Expression bestEstimate;
5355
private final Expression estimates;
5456

@@ -174,7 +176,7 @@ static void process(IntBlock.Builder builder, @Position int position, IntBlock b
174176
Number[] confidenceInterval = computeConfidenceInterval(bestEstimate, estimates);
175177
builder.beginPositionEntry();
176178
for (Number v : confidenceInterval) {
177-
builder.appendInt(v.intValue());
179+
builder.appendInt((int) Math.round(v.doubleValue()));
178180
}
179181
builder.endPositionEntry();
180182
}
@@ -195,12 +197,12 @@ static void process(LongBlock.Builder builder, @Position int position, LongBlock
195197
Number[] confidenceInterval = computeConfidenceInterval(bestEstimate, estimates);
196198
builder.beginPositionEntry();
197199
for (Number v : confidenceInterval) {
198-
builder.appendLong(v.longValue());
200+
builder.appendLong(Math.round(v.doubleValue()));
199201
}
200202
builder.endPositionEntry();
201203
}
202204

203-
private static Number[] computeConfidenceInterval(Number bestEstimate, Number[] estimates) {
205+
public static Number[] computeConfidenceInterval(Number bestEstimate, Number[] estimates) {
204206
Mean estimatesMean = new Mean();
205207
StandardDeviation estimatesStdDev = new StandardDeviation(false);
206208
Skewness estimatesSkew = new Skewness();
@@ -209,26 +211,18 @@ private static Number[] computeConfidenceInterval(Number bestEstimate, Number[]
209211
estimatesStdDev.increment(estimate.doubleValue());
210212
estimatesSkew.increment(estimate.doubleValue());
211213
}
212-
213-
double mm = estimatesMean.getResult();
214214
double sm = estimatesStdDev.getResult();
215-
216215
if (sm == 0.0) {
217-
return new Number[] { bestEstimate, bestEstimate, bestEstimate };
216+
return new Number[] { bestEstimate, bestEstimate };
218217
}
219-
220-
double a = estimatesSkew.getResult() / 6;
221-
222-
NormalDistribution norm = new NormalDistribution(0, 1);
223-
218+
double mm = estimatesMean.getResult();
224219
double z0 = (bestEstimate.doubleValue() - mm) / sm;
225-
double dz = norm.inverseCumulativeProbability((1 + 0.95) / 2); // for 95% confidence interval
226-
double zl = z0 - dz;
227-
double zu = z0 + dz;
228-
229-
sm /= Math.sqrt(estimatesMean.getN());
230-
231-
return new Number[] { mm + sm * (z0 + zl / (1 - Math.min(0.8, a * zl))), mm + sm * (z0 + zu / (1 - Math.min(0.8, a * zu))), };
220+
double dz = normal.inverseCumulativeProbability((1 + 0.95) / 2); // for 95% confidence interval; TODO make configurable
221+
double a = estimatesSkew.getResult() / (6 * Math.sqrt(estimates.length));
222+
double zl = z0 + (z0 - dz) / (1 - Math.min(a * (z0 - dz), 0.9));
223+
double zu = z0 + (z0 + dz) / (1 - Math.min(a * (z0 + dz), 0.9));
224+
double scale = Math.max(1 / Math.sqrt(estimates.length), z0 < 0 ? z0 / zl : z0 / zu);
225+
return new Number[] { mm + scale * sm * zl, mm + sm * scale * zu };
232226
}
233227

234228
@Override

0 commit comments

Comments
 (0)