|
8 | 8 | package org.elasticsearch.xpack.esql.optimizer.rules.logical; |
9 | 9 |
|
10 | 10 | import org.elasticsearch.xpack.esql.core.expression.Expression; |
11 | | -import org.elasticsearch.xpack.esql.core.util.Holder; |
| 11 | +import org.elasticsearch.xpack.esql.core.tree.Source; |
12 | 12 | import org.elasticsearch.xpack.esql.expression.function.aggregate.HasSampleCorrection; |
| 13 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; |
13 | 14 | import org.elasticsearch.xpack.esql.plan.logical.Aggregate; |
14 | 15 | import org.elasticsearch.xpack.esql.plan.logical.Limit; |
15 | 16 | import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; |
16 | 17 | import org.elasticsearch.xpack.esql.plan.logical.MvExpand; |
17 | 18 | import org.elasticsearch.xpack.esql.plan.logical.Sample; |
18 | 19 | import org.elasticsearch.xpack.esql.rule.Rule; |
19 | 20 |
|
| 21 | +import java.util.ArrayList; |
| 22 | +import java.util.List; |
| 23 | + |
20 | 24 | public class ApplySampleCorrections extends Rule<LogicalPlan, LogicalPlan> { |
21 | 25 |
|
22 | 26 | @Override |
23 | 27 | public LogicalPlan apply(LogicalPlan logicalPlan) { |
24 | | - Holder<Expression> sampleProbability = new Holder<>(null); |
| 28 | + List<Expression> sampleProbabilities = new ArrayList<>(); |
25 | 29 | return logicalPlan.transformUp(plan -> { |
26 | 30 | if (plan instanceof Sample sample) { |
27 | | - sampleProbability.set(sample.probability()); |
| 31 | + sampleProbabilities.add(sample.probability()); |
28 | 32 | } |
29 | 33 | if (plan instanceof Limit || plan instanceof MvExpand) { |
30 | | - sampleProbability.set(null); |
| 34 | + sampleProbabilities.clear(); |
31 | 35 | } |
32 | | - if (plan instanceof Aggregate && sampleProbability.get() != null) { |
| 36 | + if (plan instanceof Aggregate && sampleProbabilities.isEmpty() == false) { |
33 | 37 | plan = plan.transformExpressionsOnly( |
34 | 38 | e -> e instanceof HasSampleCorrection hsc && hsc.isSampleCorrected() == false |
35 | | - ? hsc.sampleCorrection(sampleProbability.get()) |
| 39 | + ? hsc.sampleCorrection(getSampleProbability(sampleProbabilities, e.source())) |
36 | 40 | : e |
37 | 41 | ); |
38 | | - sampleProbability.set(null); |
| 42 | + sampleProbabilities.clear(); |
39 | 43 | } |
40 | 44 | return plan; |
41 | 45 | }); |
42 | 46 | } |
| 47 | + |
| 48 | + private Expression getSampleProbability(List<Expression> sampleProbabilities, Source source) { |
| 49 | + Expression result = null; |
| 50 | + for (Expression probability : sampleProbabilities) { |
| 51 | + result = result == null ? probability : new Mul(source, result, probability); |
| 52 | + } |
| 53 | + return result; |
| 54 | + } |
43 | 55 | } |
0 commit comments