Skip to content

Commit 34308a9

Browse files
committed
Refactored inference folding in PreOptimizer.
1 parent 99f4421 commit 34308a9

File tree

1 file changed

+33
-1
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer

1 file changed

+33
-1
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.optimizer;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.CountDownActionListener;
1112
import org.elasticsearch.xpack.esql.core.expression.Expression;
1213
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1314
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
@@ -16,6 +17,11 @@
1617
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1718
import org.elasticsearch.xpack.esql.plugin.TransportActionServices;
1819

20+
import java.util.ArrayList;
21+
import java.util.HashMap;
22+
import java.util.List;
23+
import java.util.Map;
24+
1925
/**
2026
* The class is responsible for invoking any steps that need to be applied to the logical plan,
2127
* before this is being optimized.
@@ -56,7 +62,33 @@ private InferencePreOptimizer(InferenceRunner inferenceRunner, FoldContext foldC
5662
}
5763

5864
private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
59-
plan.transformExpressionsUp(InferenceFunction.class, this::foldInferenceFunction, listener);
65+
// First let's collect all the inference functions
66+
List<InferenceFunction<?>> inferenceFunctions = new ArrayList<>();
67+
plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add);
68+
69+
if (inferenceFunctions.isEmpty()) {
70+
// No inference functions found. Return the original plan.
71+
listener.onResponse(plan);
72+
return;
73+
}
74+
75+
// This is a map of inference functions to their results.
76+
// We will use this map to replace the inference functions in the plan.
77+
Map<InferenceFunction<?>, Expression> inferenceFunctionsToResults = new HashMap<>();
78+
79+
// Prepare a listener that will be called when all inference functions are done.
80+
// This listener will replace the inference functions in the plan with their results.
81+
CountDownActionListener completionListener = new CountDownActionListener(inferenceFunctions.size(), listener.map(ignored ->
82+
plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f))
83+
));
84+
85+
// Try to compute the result for each inference function.
86+
for (InferenceFunction<?> inferenceFunction : inferenceFunctions) {
87+
foldInferenceFunction(inferenceFunction, completionListener.map(e -> {
88+
inferenceFunctionsToResults.put(inferenceFunction, e);
89+
return null;
90+
}));
91+
}
6092
}
6193

6294
private void foldInferenceFunction(InferenceFunction<?> inferenceFunction, ActionListener<Expression> listener) {

0 commit comments

Comments
 (0)