|
8 | 8 | package org.elasticsearch.xpack.esql.optimizer; |
9 | 9 |
|
10 | 10 | import org.elasticsearch.action.ActionListener; |
| 11 | +import org.elasticsearch.action.support.CountDownActionListener; |
11 | 12 | import org.elasticsearch.xpack.esql.core.expression.Expression; |
12 | 13 | import org.elasticsearch.xpack.esql.core.expression.FoldContext; |
13 | 14 | import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; |
|
16 | 17 | import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; |
17 | 18 | import org.elasticsearch.xpack.esql.plugin.TransportActionServices; |
18 | 19 |
|
| 20 | +import java.util.ArrayList; |
| 21 | +import java.util.HashMap; |
| 22 | +import java.util.List; |
| 23 | +import java.util.Map; |
| 24 | + |
19 | 25 | /** |
20 | 26 | * The class is responsible for invoking any steps that need to be applied to the logical plan, |
21 | 27 | * before this is being optimized. |
@@ -56,7 +62,33 @@ private InferencePreOptimizer(InferenceRunner inferenceRunner, FoldContext foldC |
56 | 62 | } |
57 | 63 |
|
58 | 64 | private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) { |
59 | | - plan.transformExpressionsUp(InferenceFunction.class, this::foldInferenceFunction, listener); |
| 65 | + // First let's collect all the inference functions |
| 66 | + List<InferenceFunction<?>> inferenceFunctions = new ArrayList<>(); |
| 67 | + plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add); |
| 68 | + |
| 69 | + if (inferenceFunctions.isEmpty()) { |
| 70 | + // No inference functions found. Return the original plan. |
| 71 | + listener.onResponse(plan); |
| 72 | + return; |
| 73 | + } |
| 74 | + |
| 75 | + // This is a map of inference functions to their results. |
| 76 | + // We will use this map to replace the inference functions in the plan. |
| 77 | + Map<InferenceFunction<?>, Expression> inferenceFunctionsToResults = new HashMap<>(); |
| 78 | + |
| 79 | + // Prepare a listener that will be called when all inference functions are done. |
| 80 | + // This listener will replace the inference functions in the plan with their results. |
| 81 | + CountDownActionListener completionListener = new CountDownActionListener(inferenceFunctions.size(), listener.map(ignored -> |
| 82 | + plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f)) |
| 83 | + )); |
| 84 | + |
| 85 | + // Try to compute the result for each inference function. |
| 86 | + for (InferenceFunction<?> inferenceFunction : inferenceFunctions) { |
| 87 | + foldInferenceFunction(inferenceFunction, completionListener.map(e -> { |
| 88 | + inferenceFunctionsToResults.put(inferenceFunction, e); |
| 89 | + return null; |
| 90 | + })); |
| 91 | + } |
60 | 92 | } |
61 | 93 |
|
62 | 94 | private void foldInferenceFunction(InferenceFunction<?> inferenceFunction, ActionListener<Expression> listener) { |
|
0 commit comments