88package org .elasticsearch .xpack .esql .optimizer ;
99
1010import org .elasticsearch .action .ActionListener ;
11+ import org .elasticsearch .action .support .CountDownActionListener ;
12+ import org .elasticsearch .xpack .esql .core .expression .Expression ;
13+ import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
14+ import org .elasticsearch .xpack .esql .expression .function .inference .InferenceFunction ;
15+ import org .elasticsearch .xpack .esql .inference .InferenceFunctionEvaluator ;
16+ import org .elasticsearch .xpack .esql .inference .InferenceRunner ;
1117import org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
18+ import org .elasticsearch .xpack .esql .plugin .TransportActionServices ;
19+
20+ import java .util .ArrayList ;
21+ import java .util .HashMap ;
22+ import java .util .List ;
23+ import java .util .Map ;
1224
1325/**
1426 * The class is responsible for invoking any steps that need to be applied to the logical plan,
1931 */
2032public class LogicalPlanPreOptimizer {
2133
22- private final LogicalPreOptimizerContext preOptimizerContext ;
34+ private final InferenceFunctionFolding inferenceFunctionFolding ;
2335
24- public LogicalPlanPreOptimizer (LogicalPreOptimizerContext preOptimizerContext ) {
25- this .preOptimizerContext = preOptimizerContext ;
36+ public LogicalPlanPreOptimizer (TransportActionServices services , LogicalPreOptimizerContext preOptimizerContext ) {
37+ this .inferenceFunctionFolding = new InferenceFunctionFolding ( services . inferenceRunner (), preOptimizerContext . foldCtx ()) ;
2638 }
2739
2840 /**
@@ -44,7 +56,53 @@ public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener)
4456 }
4557
4658 private void doPreOptimize (LogicalPlan plan , ActionListener <LogicalPlan > listener ) {
47- // this is where we will be executing async tasks
48- listener .onResponse (plan );
59+ inferenceFunctionFolding .foldInferenceFunctions (plan , listener );
60+ }
61+
62+ private static class InferenceFunctionFolding {
63+ private final InferenceRunner inferenceRunner ;
64+ private final FoldContext foldContext ;
65+
66+ private InferenceFunctionFolding (InferenceRunner inferenceRunner , FoldContext foldContext ) {
67+ this .inferenceRunner = inferenceRunner ;
68+ this .foldContext = foldContext ;
69+ }
70+
71+ private void foldInferenceFunctions (LogicalPlan plan , ActionListener <LogicalPlan > listener ) {
72+ // First let's collect all the inference functions
73+ List <InferenceFunction <?>> inferenceFunctions = new ArrayList <>();
74+ plan .forEachExpressionUp (InferenceFunction .class , inferenceFunctions ::add );
75+
76+ if (inferenceFunctions .isEmpty ()) {
77+ // No inference functions found. Return the original plan.
78+ listener .onResponse (plan );
79+ return ;
80+ }
81+
82+ // This is a map of inference functions to their results.
83+ // We will use this map to replace the inference functions in the plan.
84+ Map <InferenceFunction <?>, Expression > inferenceFunctionsToResults = new HashMap <>();
85+
86+ // Prepare a listener that will be called when all inference functions are done.
87+ // This listener will replace the inference functions in the plan with their results.
88+ CountDownActionListener completionListener = new CountDownActionListener (
89+ inferenceFunctions .size (),
90+ listener .map (
91+ ignored -> plan .transformExpressionsUp (InferenceFunction .class , f -> inferenceFunctionsToResults .getOrDefault (f , f ))
92+ )
93+ );
94+
95+ // Try to compute the result for each inference function.
96+ for (InferenceFunction <?> inferenceFunction : inferenceFunctions ) {
97+ foldInferenceFunction (inferenceFunction , completionListener .map (e -> {
98+ inferenceFunctionsToResults .put (inferenceFunction , e );
99+ return null ;
100+ }));
101+ }
102+ }
103+
104+ private void foldInferenceFunction (InferenceFunction <?> inferenceFunction , ActionListener <Expression > listener ) {
105+ InferenceFunctionEvaluator .get (inferenceFunction , inferenceRunner ).eval (foldContext , listener );
106+ }
49107 }
50108}
0 commit comments