88package org .elasticsearch .xpack .esql .expression .function .scalar .approximate ;
99
1010import org .apache .commons .math3 .distribution .NormalDistribution ;
11+ import org .apache .commons .math3 .stat .descriptive .moment .Kurtosis ;
1112import org .apache .commons .math3 .stat .descriptive .moment .Mean ;
1213import org .apache .commons .math3 .stat .descriptive .moment .Skewness ;
1314import org .apache .commons .math3 .stat .descriptive .moment .StandardDeviation ;
3233import org .elasticsearch .xpack .esql .io .stream .PlanStreamInput ;
3334
3435import java .io .IOException ;
36+ import java .util .ArrayList ;
3537import java .util .List ;
3638import java .util .Objects ;
3739
@@ -193,16 +195,16 @@ static void process(
193195 return ;
194196 }
195197 double bestEstimate = bestEstimateBlock .getDouble (bestEstimateBlock .getFirstValueIndex (position ));
196- double [] estimates = new double [estimatesBlock .getValueCount (position )];
197- for (int i = 0 ; i < estimatesBlock .getValueCount (position ); i ++) {
198- estimates [i ] = estimatesBlock .getDouble (estimatesBlock .getFirstValueIndex (position ) + i );
199- }
200198 int trialCount = trialCountBlock .getInt (trialCountBlock .getFirstValueIndex (position ));
201199 int bucketCount = bucketCountBlock .getInt (bucketCountBlock .getFirstValueIndex (position ));
202- if (estimates . length != trialCount * bucketCount ) {
200+ if (estimatesBlock . getValueCount ( position ) != trialCount * bucketCount ) {
203201 builder .appendNull ();
204202 return ;
205203 }
204+ double [] estimates = new double [estimatesBlock .getValueCount (position )];
205+ for (int i = 0 ; i < estimatesBlock .getValueCount (position ); i ++) {
206+ estimates [i ] = estimatesBlock .getDouble (estimatesBlock .getFirstValueIndex (position ) + i );
207+ }
206208 double confidenceLevel = confidenceLevelBlock .getDouble (confidenceLevelBlock .getFirstValueIndex (position ));
207209 double [] confidenceInterval = computeConfidenceInterval (bestEstimate , estimates , trialCount , bucketCount , confidenceLevel );
208210 if (confidenceInterval == null ) {
@@ -237,13 +239,17 @@ static double[] computeConfidenceInterval(
237239 meanZeroNaN .increment (0.0 );
238240 }
239241 }
240- if (meanIgnoreNaN .getN () >= 3 ) {
241- meansIgnoreNaN .increment (meanIgnoreNaN .getResult ());
242+ double value ;
243+ if (Double .isNaN (value = meanIgnoreNaN .getResult ()) == false ) {
244+ meansIgnoreNaN .increment (value );
242245 }
243- if (meanZeroNaN .getN () >= 3 ) {
244- meansZeroNaN .increment (meanZeroNaN . getResult () );
246+ if (Double . isNaN ( value = meanZeroNaN .getResult ()) == false ) {
247+ meansZeroNaN .increment (value );
245248 }
246249 }
250+ if (Double .isNaN (meansIgnoreNaN .getResult ()) || Double .isNaN (meansZeroNaN .getResult ())) {
251+ return null ;
252+ }
247253
248254 double meanIgnoreNan = meansIgnoreNaN .getResult ();
249255 double meanZeroNan = meansZeroNaN .getResult ();
@@ -253,42 +259,73 @@ static double[] computeConfidenceInterval(
253259
254260 Mean stddevs = new Mean ();
255261 Mean skews = new Mean ();
262+ Mean kurtoses = new Mean ();
263+ int reliableCount = 0 ;
256264 for (int trial = 0 ; trial < trialCount ; trial ++) {
257- StandardDeviation stdDev = new StandardDeviation (false );
265+ StandardDeviation stddev = new StandardDeviation (false );
258266 Skewness skew = new Skewness ();
267+ Kurtosis kurtosis = new Kurtosis ();
268+ boolean hasNans = false ;
259269 for (int bucket = 0 ; bucket < bucketCount ; bucket ++) {
260270 double estimate = estimates [trial * bucketCount + bucket ];
261271 if (Double .isNaN (estimate )) {
272+ hasNans = true ;
262273 if (ignoreNaNs ) {
263274 continue ;
264275 } else {
265276 estimate = 0.0 ;
266277 }
267278 }
268- stdDev .increment (estimate );
279+ stddev .increment (estimate );
269280 skew .increment (estimate );
281+ kurtosis .increment (estimate );
270282 }
271- if (skew .getN () >= 3 ) {
272- stddevs .increment (stdDev .getResult ());
273- skews .increment (skew .getResult ());
283+ double stddevResult = stddev .getResult ();
284+ if (Double .isNaN (stddevResult ) == false ) {
285+ stddevs .increment (stddevResult );
286+ }
287+ double skewResult = skew .getResult ();
288+ if (Double .isNaN (skewResult ) == false ) {
289+ skews .increment (skewResult );
290+ }
291+ double kurtosisResult = kurtosis .getResult ();
292+ if (Double .isNaN (kurtosisResult ) == false ) {
293+ kurtoses .increment (kurtosisResult );
294+ }
295+ if (hasNans == false && computeReliable (skewResult , kurtosisResult , bucketCount )) {
296+ reliableCount ++;
274297 }
275298 }
276299
277300 double sm = stddevs .getResult ();
301+ double skew = skews .getResult ();
302+ if (Double .isNaN (sm ) || Double .isNaN (skew )) {
303+ return null ;
304+ }
278305 if (sm == 0.0 ) {
279- return new double [] { bestEstimate , bestEstimate };
306+ return new double [] { bestEstimate , bestEstimate , 1.0 };
280307 }
281308
282309 // Scale the acceleration to account for the dependence of skewness on sample size.
283310 double scale = 1 / Math .sqrt (bucketCount );
284- double a = scale * skews . getResult () / 6.0 ;
311+ double a = scale * skew / 6.0 ;
285312 double z0 = (bestEstimate - mm ) / sm ;
286313 double dz = normal .inverseCumulativeProbability ((1.0 + confidenceLevel ) / 2.0 );
287314 double zl = z0 + (z0 - dz ) / (1.0 - Math .min (a * (z0 - dz ), 0.9 ));
288315 double zu = z0 + (z0 + dz ) / (1.0 - Math .min (a * (z0 + dz ), 0.9 ));
289316 double lower = mm + scale * sm * zl ;
290317 double upper = mm + scale * sm * zu ;
291- return lower <= bestEstimate && bestEstimate <= upper ? new double [] { lower , upper } : null ;
318+
319+ return lower <= bestEstimate && bestEstimate <= upper ? new double [] { lower , upper , (double ) reliableCount / trialCount } : null ;
320+ }
321+
322+ static boolean computeReliable (double skew , double kurtosis , int B ) {
323+ if (Double .isNaN (skew ) || Double .isNaN (kurtosis ) || B < 4 ) {
324+ return false ;
325+ }
326+ double maxSkew = Math .sqrt (6.0 * B * (B - 1 ) / ((B - 2 ) * (B + 1 ) * (B + 3 ))) * 1.96 ;
327+ double maxKurtosis = Math .sqrt (24.0 * B * (B - 1 ) * (B - 1 ) / ((B - 3 ) * (B - 2 ) * (B + 3 ) * (B + 5 ))) * 1.96 ;
328+ return Math .abs (skew ) < maxSkew && Math .abs (kurtosis ) < maxKurtosis ;
292329 }
293330
294331 @ Override
0 commit comments