Skip to content

Commit 04bce56

Browse files
committed
propagate multiple sample probabilities
1 parent 05f1151 commit 04bce56

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ApplySampleCorrections.java

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,48 @@
88
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
99

1010
import org.elasticsearch.xpack.esql.core.expression.Expression;
11-
import org.elasticsearch.xpack.esql.core.util.Holder;
11+
import org.elasticsearch.xpack.esql.core.tree.Source;
1212
import org.elasticsearch.xpack.esql.expression.function.aggregate.HasSampleCorrection;
13+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
1314
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
1415
import org.elasticsearch.xpack.esql.plan.logical.Limit;
1516
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1617
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
1718
import org.elasticsearch.xpack.esql.plan.logical.Sample;
1819
import org.elasticsearch.xpack.esql.rule.Rule;
1920

21+
import java.util.ArrayList;
22+
import java.util.List;
23+
2024
public class ApplySampleCorrections extends Rule<LogicalPlan, LogicalPlan> {
2125

2226
@Override
2327
public LogicalPlan apply(LogicalPlan logicalPlan) {
24-
Holder<Expression> sampleProbability = new Holder<>(null);
28+
List<Expression> sampleProbabilities = new ArrayList<>();
2529
return logicalPlan.transformUp(plan -> {
2630
if (plan instanceof Sample sample) {
27-
sampleProbability.set(sample.probability());
31+
sampleProbabilities.add(sample.probability());
2832
}
2933
if (plan instanceof Limit || plan instanceof MvExpand) {
30-
sampleProbability.set(null);
34+
sampleProbabilities.clear();
3135
}
32-
if (plan instanceof Aggregate && sampleProbability.get() != null) {
36+
if (plan instanceof Aggregate && sampleProbabilities.isEmpty() == false) {
3337
plan = plan.transformExpressionsOnly(
3438
e -> e instanceof HasSampleCorrection hsc && hsc.isSampleCorrected() == false
35-
? hsc.sampleCorrection(sampleProbability.get())
39+
? hsc.sampleCorrection(getSampleProbability(sampleProbabilities, e.source()))
3640
: e
3741
);
38-
sampleProbability.set(null);
42+
sampleProbabilities.clear();
3943
}
4044
return plan;
4145
});
4246
}
47+
48+
private Expression getSampleProbability(List<Expression> sampleProbabilities, Source source) {
49+
Expression result = null;
50+
for (Expression probability : sampleProbabilities) {
51+
result = result == null ? probability : new Mul(source, result, probability);
52+
}
53+
return result;
54+
}
4355
}

0 commit comments

Comments
 (0)