|
8 | 8 | package org.elasticsearch.xpack.esql.inference; |
9 | 9 |
|
10 | 10 | import org.elasticsearch.action.ActionListener; |
| 11 | +import org.elasticsearch.common.breaker.CircuitBreaker; |
11 | 12 | import org.elasticsearch.common.lucene.BytesRefs; |
| 13 | +import org.elasticsearch.common.util.BigArrays; |
| 14 | +import org.elasticsearch.compute.data.BlockFactory; |
| 15 | +import org.elasticsearch.compute.data.BlockUtils; |
| 16 | +import org.elasticsearch.compute.data.Page; |
| 17 | +import org.elasticsearch.compute.operator.DriverContext; |
12 | 18 | import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; |
13 | 19 | import org.elasticsearch.compute.operator.Operator; |
| 20 | +import org.elasticsearch.indices.breaker.AllCircuitBreakerStats; |
| 21 | +import org.elasticsearch.indices.breaker.CircuitBreakerService; |
| 22 | +import org.elasticsearch.indices.breaker.CircuitBreakerStats; |
14 | 23 | import org.elasticsearch.xpack.esql.core.expression.Expression; |
15 | 24 | import org.elasticsearch.xpack.esql.core.expression.FoldContext; |
16 | 25 | import org.elasticsearch.xpack.esql.core.expression.Literal; |
|
19 | 28 | import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; |
20 | 29 | import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator; |
21 | 30 |
|
| 31 | +/** |
| 32 | + * Evaluator for inference functions that performs constant folding by executing inference operations |
| 33 | + * at optimization time and replacing them with their computed results. |
| 34 | + * <p> |
| 35 | + * This class is responsible for: |
| 36 | + * <ul> |
| 37 | + * <li>Setting up the necessary execution context (DriverContext, CircuitBreaker, etc.)</li> |
| 38 | + * <li>Creating and configuring appropriate inference operators for different function types</li> |
| 39 | + * <li>Executing inference operations asynchronously</li> |
| 40 | + * <li>Converting operator results back to ESQL expressions</li> |
| 41 | + * </ul> |
| 42 | + */ |
22 | 43 | public class InferenceFunctionEvaluator { |
23 | 44 |
|
24 | 45 | private final FoldContext foldContext; |
25 | 46 | private final InferenceService inferenceService; |
| 47 | + private final InferenceOperatorProvider inferenceOperatorProvider; |
26 | 48 |
|
| 49 | + /** |
| 50 | + * Creates a new inference function evaluator with the default operator provider. |
| 51 | + * |
| 52 | + * @param foldContext the fold context containing circuit breakers and evaluation settings |
| 53 | + * @param inferenceService the inference service for executing inference operations |
| 54 | + */ |
27 | 55 | public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) { |
28 | 56 | this.foldContext = foldContext; |
29 | 57 | this.inferenceService = inferenceService; |
| 58 | + this.inferenceOperatorProvider = this::createInferenceOperator; |
| 59 | + } |
| 60 | + |
| 61 | + /** |
| 62 | + * Creates a new inference function evaluator with a custom operator provider. |
| 63 | + * This constructor is primarily used for testing to inject mock operator providers. |
| 64 | + * |
| 65 | + * @param foldContext the fold context containing circuit breakers and evaluation settings |
| 66 | + * @param inferenceService the inference service for executing inference operations |
| 67 | + * @param inferenceOperatorProvider custom provider for creating inference operators |
| 68 | + */ |
| 69 | + InferenceFunctionEvaluator( |
| 70 | + FoldContext foldContext, |
| 71 | + InferenceService inferenceService, |
| 72 | + InferenceOperatorProvider inferenceOperatorProvider |
| 73 | + ) { |
| 74 | + this.foldContext = foldContext; |
| 75 | + this.inferenceService = inferenceService; |
| 76 | + this.inferenceOperatorProvider = inferenceOperatorProvider; |
30 | 77 | } |
31 | 78 |
|
32 | | - public void fold(InferenceFunction<?> f, ActionListener<Object> listener) { |
33 | | - assert f.foldable() : "Inference function must be foldable"; |
| 79 | + /** |
| 80 | + * Folds an inference function by executing it and replacing it with its computed result. |
| 81 | + * <p> |
| 82 | + * This method performs the following steps: |
| 83 | + * <ol> |
| 84 | + * <li>Validates that the function is foldable (has constant parameters)</li> |
| 85 | + * <li>Sets up a minimal execution context with appropriate circuit breakers</li> |
| 86 | + * <li>Creates and configures the appropriate inference operator</li> |
| 87 | + * <li>Executes the inference operation asynchronously</li> |
| 88 | + * <li>Converts the result to a {@link Literal} expression</li> |
| 89 | + * </ol> |
| 90 | + * |
| 91 | + * @param f the inference function to fold - must be foldable (have constant parameters) |
| 92 | + * @param listener the listener to notify when folding completes successfully or fails |
| 93 | + * @throws IllegalArgumentException if the function is not foldable |
| 94 | + */ |
| 95 | + public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) { |
| 96 | + if (f.foldable() == false) { |
| 97 | + listener.onFailure(new IllegalArgumentException("Inference function must be foldable")); |
| 98 | + return; |
| 99 | + } |
| 100 | + |
| 101 | + // Set up a DriverContext for executing the inference operator. |
| 102 | + // This follows the same pattern as EvaluatorMapper but in a simplified context |
| 103 | + // suitable for constant folding during optimization. |
| 104 | + CircuitBreaker breaker = foldContext.circuitBreakerView(f.source()); |
| 105 | + BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() { |
| 106 | + @Override |
| 107 | + public CircuitBreaker getBreaker(String name) { |
| 108 | + if (name.equals(CircuitBreaker.REQUEST) == false) { |
| 109 | + throw new UnsupportedOperationException("Only REQUEST circuit breaker is supported"); |
| 110 | + } |
| 111 | + return breaker; |
| 112 | + } |
| 113 | + |
| 114 | + @Override |
| 115 | + public AllCircuitBreakerStats stats() { |
| 116 | + throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context"); |
| 117 | + } |
| 118 | + |
| 119 | + @Override |
| 120 | + public CircuitBreakerStats stats(String name) { |
| 121 | + throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context"); |
| 122 | + } |
| 123 | + }, CircuitBreaker.REQUEST).withCircuitBreaking(); |
| 124 | + |
| 125 | + DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); |
34 | 126 |
|
| 127 | + // Create the inference operator for the specific function type using the provider |
| 128 | + Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext); |
35 | 129 |
|
| 130 | + // Execute the inference operation asynchronously and handle the result |
| 131 | + // The operator will perform the actual inference call and return a page with the result |
| 132 | + driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { |
| 133 | + Page output = inferenceOperator.getOutput(); |
| 134 | + |
| 135 | + if (output == null) { |
| 136 | + l.onFailure(new IllegalStateException("Expected output page from inference operator")); |
| 137 | + return; |
| 138 | + } |
| 139 | + |
| 140 | + if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { |
| 141 | + l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); |
| 142 | + return; |
| 143 | + } |
| 144 | + |
| 145 | + // Convert the operator result back to an ESQL expression (Literal) |
| 146 | + l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); |
| 147 | + })); |
| 148 | + |
| 149 | + // Feed the operator with a single page to trigger execution |
| 150 | + // The actual input data is already bound in the operator through expression evaluators |
| 151 | + inferenceOperator.addInput(new Page(1)); |
| 152 | + |
| 153 | + driverContext.finish(); |
36 | 154 | } |
37 | 155 |
|
38 | | - private Operator.OperatorFactory createInferenceOperatorFactory(InferenceFunction<?> f) { |
39 | | - return switch (f) { |
| 156 | + /** |
| 157 | + * Creates an inference operator for the given function type and driver context. |
| 158 | + * <p> |
| 159 | + * This method uses pattern matching to determine the correct operator factory based on |
| 160 | + * the inference function type, creates the factory, and then instantiates the operator |
| 161 | + * with the provided driver context. Each supported inference function type has its own |
| 162 | + * specialized operator implementation. |
| 163 | + * |
| 164 | + * @param f the inference function to create an operator for |
| 165 | + * @param driverContext the driver context to use for operator creation |
| 166 | + * @return an operator instance configured for the given function type |
| 167 | + * @throws IllegalArgumentException if the function type is not supported |
| 168 | + */ |
| 169 | + private Operator createInferenceOperator(InferenceFunction<?> f, DriverContext driverContext) { |
| 170 | + Operator.OperatorFactory factory = switch (f) { |
40 | 171 | case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory( |
41 | 172 | inferenceService, |
42 | 173 | inferenceId(f), |
43 | 174 | expressionEvaluatorFactory(textEmbedding.inputText()) |
44 | 175 | ); |
45 | 176 | default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); |
46 | 177 | }; |
| 178 | + |
| 179 | + return factory.get(driverContext); |
47 | 180 | } |
48 | 181 |
|
| 182 | + /** |
| 183 | + * Extracts the inference endpoint ID from an inference function. |
| 184 | + * |
| 185 | + * @param f the inference function containing the inference ID |
| 186 | + * @return the inference endpoint ID as a string |
| 187 | + */ |
49 | 188 | private String inferenceId(InferenceFunction<?> f) { |
50 | 189 | return BytesRefs.toString(f.inferenceId().fold(foldContext)); |
51 | 190 | } |
52 | 191 |
|
| 192 | + /** |
| 193 | + * Creates an expression evaluator factory for a foldable expression. |
| 194 | + * <p> |
| 195 | + * This method converts a foldable expression into an evaluator factory that can be used by inference |
| 196 | + * operators. The expression is first folded to its constant value and then wrapped in a literal. |
| 197 | + * |
| 198 | + * @param e the foldable expression to create an evaluator factory for |
| 199 | + * @return an expression evaluator factory for the given expression |
| 200 | + * @throws AssertionError if the expression is not foldable (in debug builds) |
| 201 | + */ |
53 | 202 | private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) { |
54 | 203 | assert e.foldable() : "Input expression must be foldable"; |
55 | 204 | return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null); |
56 | 205 | } |
| 206 | + |
| 207 | + /** |
| 208 | + * Functional interface for providing inference operators. |
| 209 | + * <p> |
| 210 | + * This interface abstracts the creation of inference operators for different function types, |
| 211 | + * allowing for easier testing and potential future extensibility. The provider is responsible |
| 212 | + * for creating an appropriate operator instance given an inference function and driver context. |
| 213 | + */ |
| 214 | + @FunctionalInterface |
| 215 | + interface InferenceOperatorProvider { |
| 216 | + /** |
| 217 | + * Creates an inference operator for the given function and driver context. |
| 218 | + * |
| 219 | + * @param f the inference function to create an operator for |
| 220 | + * @param driverContext the driver context to use for operator creation |
| 221 | + * @return an operator instance configured for the given function |
| 222 | + */ |
| 223 | + Operator getOperator(InferenceFunction<?> f, DriverContext driverContext); |
| 224 | + } |
57 | 225 | } |
0 commit comments