Skip to content

Commit 141d083

Browse files
committed
separate confidence interval column + fix to_string/date etc
1 parent 62550a5 commit 141d083

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/approximate/Approximate.java

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
import org.elasticsearch.xpack.esql.core.expression.Literal;
2020
import org.elasticsearch.xpack.esql.core.expression.NameId;
2121
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
22+
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
2223
import org.elasticsearch.xpack.esql.core.tree.Source;
2324
import org.elasticsearch.xpack.esql.core.type.DataType;
2425
import org.elasticsearch.xpack.esql.core.util.Holder;
2526
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
2627
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
2728
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
2829
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
30+
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
2931
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
3032
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.ConfidenceInterval;
3133
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend;
@@ -300,6 +302,7 @@ private LogicalPlan approximatePlan(double sampleProbability) {
300302
logger.debug("generating approximate plan (p={})", sampleProbability);
301303
Holder<Boolean> encounteredStats = new Holder<>(false);
302304
Set<NameId> variablesWithConfidenceInterval = new HashSet<>();
305+
Set<NameId> variablesWithPastConfidenceInterval = new HashSet<>();
303306

304307
Alias bucketId = new Alias(
305308
Source.EMPTY,
@@ -345,21 +348,33 @@ private LogicalPlan approximatePlan(double sampleProbability) {
345348
} else if (encounteredStats.get()) {
346349
System.out.println("@@@ UPDATE variablesWithConfidenceInterval");
347350
System.out.println("plan = " + plan);
348-
System.out.println("vars = " + variablesWithConfidenceInterval);
351+
System.out.println("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval);
349352
switch (plan) {
350353
case Eval eval:
351-
for (NamedExpression field : eval.fields()) {
354+
for (Alias field : eval.fields()) {
352355
if (field.anyMatch(expr -> expr instanceof NamedExpression named && variablesWithConfidenceInterval.contains(named.id()))) {
353-
variablesWithConfidenceInterval.add(field.id());
356+
// TODO: blacklist / whitelist?
357+
if (field.child() instanceof MvAppend == false && field.dataType().isNumeric()) {
358+
variablesWithConfidenceInterval.add(field.id());
359+
} else {
360+
variablesWithPastConfidenceInterval.add(field.id());
361+
}
362+
} else if (field.anyMatch(expr -> expr instanceof NamedExpression named && variablesWithPastConfidenceInterval.contains(named.id()))) {
363+
variablesWithPastConfidenceInterval.add(field.id());
354364
}
355365
}
356366
break;
367+
case Project project:
368+
List<NamedExpression> projections = new ArrayList<>(project.projections());
369+
projections.add(bucketId.toAttribute());
370+
plan = project.withProjections(projections);
371+
break;
357372
case Rename rename:
358373
// TODO
359374
break;
360375
default:
361376
}
362-
System.out.println("vars = " + variablesWithConfidenceInterval);
377+
System.out.println("vars = " + variablesWithConfidenceInterval + " / " + variablesWithPastConfidenceInterval);
363378
}
364379
return plan;
365380
});
@@ -372,17 +387,22 @@ private LogicalPlan approximatePlan(double sampleProbability) {
372387
if (attribute.id() == bucketId.id()) {
373388
continue;
374389
}
375-
if (variablesWithConfidenceInterval.contains(attribute.id())) {
376-
aggregates.add(new Alias(
390+
if (variablesWithConfidenceInterval.contains(attribute.id()) || variablesWithPastConfidenceInterval.contains(attribute.id())) {
391+
Alias bestEstimate = new Alias(
377392
Source.EMPTY,
378393
attribute.name(),
379-
new ConfidenceInterval(
394+
new Values(
380395
Source.EMPTY,
381-
new Min(
382-
Source.EMPTY,
383-
attribute,
384-
new Equals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1))
385-
),
396+
attribute,
397+
new Equals(Source.EMPTY, bucketId.toAttribute(), Literal.integer(Source.EMPTY, -1))
398+
)
399+
);
400+
aggregates.add(bestEstimate);
401+
if (variablesWithConfidenceInterval.contains(attribute.id())) {
402+
aggregates.add(new Alias(
403+
Source.EMPTY, "CONFIDENCE_INTERVAL(" + attribute.name() + ")", new ConfidenceInterval(
404+
Source.EMPTY,
405+
bestEstimate.toAttribute(),
386406
new Top(
387407
Source.EMPTY,
388408
attribute,
@@ -391,9 +411,11 @@ private LogicalPlan approximatePlan(double sampleProbability) {
391411
Literal.keyword(Source.EMPTY, "ASC")
392412
),
393413
Literal.integer(Source.EMPTY, BUCKET_COUNT),
394-
Literal.fromDouble(Source.EMPTY, 0.0) // TODO: fix, 0.0 or NaN ??
414+
Literal.fromDouble(Source.EMPTY, 0.0)
415+
// TODO: fix, 0.0 or NaN ?? TODO: remove!!
395416
)
396-
));
417+
));
418+
}
397419
} else {
398420
aggregates.add(attribute);
399421
groupings.add(attribute);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/ConfidenceInterval.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ private static Number[] computeConfidenceInterval(Number bestEstimate, Number[]
249249

250250
return new Number[] {
251251
mm + sm * (z0 + zl / (1 - Math.min(0.8, a * zl))),
252-
bestEstimate,
253252
mm + sm * (z0 + zu / (1 - Math.min(0.8, a * zu))),
254253
};
255254
}

0 commit comments

Comments
 (0)