Skip to content

Commit 6d064f0

Browse files
committed
Implementing inference evaluation in the pre-optimizer.
1 parent 54e9ca6 commit 6d064f0

File tree

6 files changed

+352
-32
lines changed

6 files changed

+352
-32
lines changed

x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
7878
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS;
7979
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING;
80+
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION;
8081
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial;
8182
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities;
8283

@@ -205,7 +206,8 @@ protected boolean requiresInferenceEndpoint() {
205206
SEMANTIC_TEXT_FIELD_CAPS.capabilityName(),
206207
RERANK.capabilityName(),
207208
COMPLETION.capabilityName(),
208-
KNN_FUNCTION_V5.capabilityName()
209+
KNN_FUNCTION_V5.capabilityName(),
210+
TEXT_EMBEDDING_FUNCTION.capabilityName()
209211
).anyMatch(testCase.requiredCapabilities::contains);
210212
}
211213

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
text_embedding using a ROW source operator
2+
required_capability: text_embedding_function
3+
required_capability: dense_vector_field_type
4+
5+
ROW input="Who is Victor Hugo?"
6+
| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")
7+
;
8+
9+
input:keyword | embedding:dense_vector
10+
Who is Victor Hugo? | [56.0, 50.0, 48.0]
11+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 172 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,18 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.common.breaker.CircuitBreaker;
1112
import org.elasticsearch.common.lucene.BytesRefs;
13+
import org.elasticsearch.common.util.BigArrays;
14+
import org.elasticsearch.compute.data.BlockFactory;
15+
import org.elasticsearch.compute.data.BlockUtils;
16+
import org.elasticsearch.compute.data.Page;
17+
import org.elasticsearch.compute.operator.DriverContext;
1218
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
1319
import org.elasticsearch.compute.operator.Operator;
20+
import org.elasticsearch.indices.breaker.AllCircuitBreakerStats;
21+
import org.elasticsearch.indices.breaker.CircuitBreakerService;
22+
import org.elasticsearch.indices.breaker.CircuitBreakerStats;
1423
import org.elasticsearch.xpack.esql.core.expression.Expression;
1524
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1625
import org.elasticsearch.xpack.esql.core.expression.Literal;
@@ -19,39 +28,198 @@
1928
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
2029
import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator;
2130

31+
/**
32+
* Evaluator for inference functions that performs constant folding by executing inference operations
33+
* at optimization time and replacing them with their computed results.
34+
* <p>
35+
* This class is responsible for:
36+
* <ul>
37+
* <li>Setting up the necessary execution context (DriverContext, CircuitBreaker, etc.)</li>
38+
* <li>Creating and configuring appropriate inference operators for different function types</li>
39+
* <li>Executing inference operations asynchronously</li>
40+
* <li>Converting operator results back to ESQL expressions</li>
41+
* </ul>
42+
*/
2243
public class InferenceFunctionEvaluator {
2344

2445
private final FoldContext foldContext;
2546
private final InferenceService inferenceService;
47+
private final InferenceOperatorProvider inferenceOperatorProvider;
2648

49+
/**
50+
* Creates a new inference function evaluator with the default operator provider.
51+
*
52+
* @param foldContext the fold context containing circuit breakers and evaluation settings
53+
* @param inferenceService the inference service for executing inference operations
54+
*/
2755
public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) {
2856
this.foldContext = foldContext;
2957
this.inferenceService = inferenceService;
58+
this.inferenceOperatorProvider = this::createInferenceOperator;
59+
}
60+
61+
/**
62+
* Creates a new inference function evaluator with a custom operator provider.
63+
* This constructor is primarily used for testing to inject mock operator providers.
64+
*
65+
* @param foldContext the fold context containing circuit breakers and evaluation settings
66+
* @param inferenceService the inference service for executing inference operations
67+
* @param inferenceOperatorProvider custom provider for creating inference operators
68+
*/
69+
InferenceFunctionEvaluator(
70+
FoldContext foldContext,
71+
InferenceService inferenceService,
72+
InferenceOperatorProvider inferenceOperatorProvider
73+
) {
74+
this.foldContext = foldContext;
75+
this.inferenceService = inferenceService;
76+
this.inferenceOperatorProvider = inferenceOperatorProvider;
3077
}
3178

32-
public void fold(InferenceFunction<?> f, ActionListener<Object> listener) {
33-
assert f.foldable() : "Inference function must be foldable";
79+
/**
80+
* Folds an inference function by executing it and replacing it with its computed result.
81+
* <p>
82+
* This method performs the following steps:
83+
* <ol>
84+
* <li>Validates that the function is foldable (has constant parameters)</li>
85+
* <li>Sets up a minimal execution context with appropriate circuit breakers</li>
86+
* <li>Creates and configures the appropriate inference operator</li>
87+
* <li>Executes the inference operation asynchronously</li>
88+
* <li>Converts the result to a {@link Literal} expression</li>
89+
* </ol>
90+
*
91+
* @param f the inference function to fold - must be foldable (have constant parameters)
92+
* @param listener the listener to notify when folding completes successfully or fails
93+
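* @throws IllegalArgumentException if the default operator provider does not support the function type; a non-foldable function is not thrown but reported through {@code listener.onFailure}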
94+
*/
95+
public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) {
96+
if (f.foldable() == false) {
97+
listener.onFailure(new IllegalArgumentException("Inference function must be foldable"));
98+
return;
99+
}
100+
101+
// Set up a DriverContext for executing the inference operator.
102+
// This follows the same pattern as EvaluatorMapper but in a simplified context
103+
// suitable for constant folding during optimization.
104+
CircuitBreaker breaker = foldContext.circuitBreakerView(f.source());
105+
BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() {
106+
@Override
107+
public CircuitBreaker getBreaker(String name) {
108+
if (name.equals(CircuitBreaker.REQUEST) == false) {
109+
throw new UnsupportedOperationException("Only REQUEST circuit breaker is supported");
110+
}
111+
return breaker;
112+
}
113+
114+
@Override
115+
public AllCircuitBreakerStats stats() {
116+
throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context");
117+
}
118+
119+
@Override
120+
public CircuitBreakerStats stats(String name) {
121+
throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context");
122+
}
123+
}, CircuitBreaker.REQUEST).withCircuitBreaking();
124+
125+
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
34126

127+
// Create the inference operator for the specific function type using the provider
128+
Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext);
35129

130+
// Execute the inference operation asynchronously and handle the result
131+
// The operator will perform the actual inference call and return a page with the result
132+
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
133+
Page output = inferenceOperator.getOutput();
134+
135+
if (output == null) {
136+
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
137+
return;
138+
}
139+
140+
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
141+
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
142+
return;
143+
}
144+
145+
// Convert the operator result back to an ESQL expression (Literal)
146+
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
147+
}));
148+
149+
// Feed the operator with a single page to trigger execution
150+
// The actual input data is already bound in the operator through expression evaluators
151+
inferenceOperator.addInput(new Page(1));
152+
153+
driverContext.finish();
36154
}
37155

38-
private Operator.OperatorFactory createInferenceOperatorFactory(InferenceFunction<?> f) {
39-
return switch (f) {
156+
/**
157+
* Creates an inference operator for the given function type and driver context.
158+
* <p>
159+
* This method uses pattern matching to determine the correct operator factory based on
160+
* the inference function type, creates the factory, and then instantiates the operator
161+
* with the provided driver context. Each supported inference function type has its own
162+
* specialized operator implementation.
163+
*
164+
* @param f the inference function to create an operator for
165+
* @param driverContext the driver context to use for operator creation
166+
* @return an operator instance configured for the given function type
167+
* @throws IllegalArgumentException if the function type is not supported
168+
*/
169+
private Operator createInferenceOperator(InferenceFunction<?> f, DriverContext driverContext) {
170+
Operator.OperatorFactory factory = switch (f) {
40171
case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory(
41172
inferenceService,
42173
inferenceId(f),
43174
expressionEvaluatorFactory(textEmbedding.inputText())
44175
);
45176
default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName());
46177
};
178+
179+
return factory.get(driverContext);
47180
}
48181

182+
/**
183+
* Extracts the inference endpoint ID from an inference function.
184+
*
185+
* @param f the inference function containing the inference ID
186+
* @return the inference endpoint ID as a string
187+
*/
49188
private String inferenceId(InferenceFunction<?> f) {
50189
return BytesRefs.toString(f.inferenceId().fold(foldContext));
51190
}
52191

192+
/**
193+
* Creates an expression evaluator factory for a foldable expression.
194+
* <p>
195+
* This method converts a foldable expression into an evaluator factory that can be used by inference
196+
* operators. The expression is first folded to its constant value and then wrapped in a literal.
197+
*
198+
* @param e the foldable expression to create an evaluator factory for
199+
* @return an expression evaluator factory for the given expression
200+
* @throws AssertionError if the expression is not foldable (only when JVM assertions are enabled)
201+
*/
53202
private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) {
54203
assert e.foldable() : "Input expression must be foldable";
55204
return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null);
56205
}
206+
207+
/**
208+
* Functional interface for providing inference operators.
209+
* <p>
210+
* This interface abstracts the creation of inference operators for different function types,
211+
* allowing for easier testing and potential future extensibility. The provider is responsible
212+
* for creating an appropriate operator instance given an inference function and driver context.
213+
*/
214+
@FunctionalInterface
215+
interface InferenceOperatorProvider {
216+
/**
217+
* Creates an inference operator for the given function and driver context.
218+
*
219+
* @param f the inference function to create an operator for
220+
* @param driverContext the driver context to use for operator creation
221+
* @return an operator instance configured for the given function
222+
*/
223+
Operator getOperator(InferenceFunction<?> f, DriverContext driverContext);
224+
}
57225
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.support.SubscribableListener;
12-
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
12+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.FoldInferenceFunctions;
13+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule;
1314
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1415

1516
import java.util.List;
@@ -22,15 +23,12 @@
2223
* </p>
2324
*/
2425
public class LogicalPlanPreOptimizer {
25-
26-
private final LogicalPreOptimizerContext preOptimizerContext;
26+
private final List<PreOptimizerRule> preOptimizerRules;
2727

2828
public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) {
29-
this.preOptimizerContext = preOptimizerContext;
29+
preOptimizerRules = List.of(new FoldInferenceFunctions(preOptimizerContext));
3030
}
3131

32-
private static final List<Rule> RULES = List.of();
33-
3432
/**
3533
* Pre-optimize a logical plan.
3634
*
@@ -49,28 +47,17 @@ public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener)
4947
}));
5048
}
5149

50+
/**
51+
* Loop over the rules and apply them sequentially to the logical plan.
52+
*
53+
* @param plan the analyzed logical plan to pre-optimize
54+
* @param listener the listener returning the pre-optimized plan when pre-optimization is complete
55+
*/
5256
private void doPreOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
53-
SubscribableListener<LogicalPlan> ruleChainListener = SubscribableListener.newSucceeded(plan);
54-
for (Rule rule : RULES) {
55-
ruleChainListener = ruleChainListener.andThen((l, p) -> rule.apply(p, l));
56-
}
57-
ruleChainListener.addListener(listener);
58-
}
59-
60-
public interface Rule {
61-
void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener);
62-
}
63-
64-
private static class FoldInferenceFunction implements Rule {
65-
private final InferenceFunctionEvaluator inferenceEvaluator;
66-
67-
private FoldInferenceFunction(LogicalPreOptimizerContext preOptimizerContext) {
68-
this.inferenceEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService());
69-
}
70-
71-
@Override
72-
public void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
73-
57+
SubscribableListener<LogicalPlan> rulesListener = SubscribableListener.newSucceeded(plan);
58+
for (PreOptimizerRule preOptimizerRule : preOptimizerRules) {
59+
rulesListener = rulesListener.andThen((l, p) -> preOptimizerRule.apply(p, l));
7460
}
61+
rulesListener.addListener(listener);
7562
}
7663
}

0 commit comments

Comments
 (0)