1515import org .elasticsearch .compute .data .BlockUtils ;
1616import org .elasticsearch .compute .data .Page ;
1717import org .elasticsearch .compute .operator .DriverContext ;
18- import org .elasticsearch .compute .operator .EvalOperator . ExpressionEvaluator ;
18+ import org .elasticsearch .compute .operator .EvalOperator ;
1919import org .elasticsearch .compute .operator .Operator ;
2020import org .elasticsearch .indices .breaker .AllCircuitBreakerStats ;
2121import org .elasticsearch .indices .breaker .CircuitBreakerService ;
4242 */
4343public class InferenceFunctionEvaluator {
4444
45- private final FoldContext foldContext ;
46- private final InferenceService inferenceService ;
47- private final InferenceOperatorProvider inferenceOperatorProvider ;
45+ private static final Factory FACTORY = new Factory ();
4846
49- /**
50- * Creates a new inference function evaluator with the default operator provider.
51- *
52- * @param foldContext the fold context containing circuit breakers and evaluation settings
53- * @param inferenceService the inference service for executing inference operations
54- */
55- public InferenceFunctionEvaluator (FoldContext foldContext , InferenceService inferenceService ) {
56- this .foldContext = foldContext ;
57- this .inferenceService = inferenceService ;
58- this .inferenceOperatorProvider = this ::createInferenceOperator ;
47+ public static InferenceFunctionEvaluator .Factory factory () {
48+ return FACTORY ;
5949 }
6050
51+ private final FoldContext foldContext ;
52+ private final InferenceOperatorProvider inferenceOperatorProvider ;
53+
6154 /**
6255 * Creates a new inference function evaluator with a custom operator provider.
6356 * This constructor is primarily used for testing to inject mock operator providers.
6457 *
6558 * @param foldContext the fold context containing circuit breakers and evaluation settings
66- * @param inferenceService the inference service for executing inference operations
6759 * @param inferenceOperatorProvider custom provider for creating inference operators
6860 */
69- InferenceFunctionEvaluator (
70- FoldContext foldContext ,
71- InferenceService inferenceService ,
72- InferenceOperatorProvider inferenceOperatorProvider
73- ) {
61+ InferenceFunctionEvaluator (FoldContext foldContext , InferenceOperatorProvider inferenceOperatorProvider ) {
7462 this .foldContext = foldContext ;
75- this .inferenceService = inferenceService ;
7663 this .inferenceOperatorProvider = inferenceOperatorProvider ;
7764 }
7865
@@ -90,7 +77,6 @@ public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService infe
9077 *
9178 * @param f the inference function to fold - must be foldable (have constant parameters)
9279 * @param listener the listener to notify when folding completes successfully or fails
93- * @throws IllegalArgumentException if the function is not foldable
9480 */
9581 public void fold (InferenceFunction <?> f , ActionListener <Expression > listener ) {
9682 if (f .foldable () == false ) {
@@ -125,89 +111,41 @@ public CircuitBreakerStats stats(String name) {
125111 DriverContext driverContext = new DriverContext (bigArrays , new BlockFactory (breaker , bigArrays ));
126112
127113 // Create the inference operator for the specific function type using the provider
128- Operator inferenceOperator = inferenceOperatorProvider .getOperator (f , driverContext );
129-
130- // Execute the inference operation asynchronously and handle the result
131- // The operator will perform the actual inference call and return a page with the result
132- driverContext .waitForAsyncActions (listener .delegateFailureIgnoreResponseAndWrap (l -> {
133- Page output = inferenceOperator .getOutput ();
134-
135- try {
136- if (output == null ) {
137- l .onFailure (new IllegalStateException ("Expected output page from inference operator" ));
138- return ;
139- }
140114
141- if (output .getPositionCount () != 1 || output .getBlockCount () != 1 ) {
142- l .onFailure (new IllegalStateException ("Expected a single block with a single value from inference operator" ));
143- return ;
115+ try (Operator inferenceOperator = inferenceOperatorProvider .getOperator (f , driverContext )) {
116+ // Execute the inference operation asynchronously and handle the result
117+ // The operator will perform the actual inference call and return a page with the result
118+ driverContext .waitForAsyncActions (listener .delegateFailureIgnoreResponseAndWrap (l -> {
119+ Page output = inferenceOperator .getOutput ();
120+
121+ try {
122+ if (output == null ) {
123+ l .onFailure (new IllegalStateException ("Expected output page from inference operator" ));
124+ return ;
125+ }
126+
127+ if (output .getPositionCount () != 1 || output .getBlockCount () != 1 ) {
128+ l .onFailure (new IllegalStateException ("Expected a single block with a single value from inference operator" ));
129+ return ;
130+ }
131+
132+ // Convert the operator result back to an ESQL expression (Literal)
133+ l .onResponse (Literal .of (f , BlockUtils .toJavaObject (output .getBlock (0 ), 0 )));
134+ } finally {
135+ if (output != null ) {
136+ output .releaseBlocks ();
137+ }
144138 }
145-
146- // Convert the operator result back to an ESQL expression (Literal)
147- l .onResponse (Literal .of (f , BlockUtils .toJavaObject (output .getBlock (0 ), 0 )));
148- } finally {
149- if (output != null ) {
150- output .releaseBlocks ();
151- }
152- }
153- }));
154-
155- // Feed the operator with a single page to trigger execution
156- // The actual input data is already bound in the operator through expression evaluators
157- inferenceOperator .addInput (new Page (1 ));
158-
159- driverContext .finish ();
160- }
161-
162- /**
163- * Creates an inference operator for the given function type and driver context.
164- * <p>
165- * This method uses pattern matching to determine the correct operator factory based on
166- * the inference function type, creates the factory, and then instantiates the operator
167- * with the provided driver context. Each supported inference function type has its own
168- * specialized operator implementation.
169- *
170- * @param f the inference function to create an operator for
171- * @param driverContext the driver context to use for operator creation
172- * @return an operator instance configured for the given function type
173- * @throws IllegalArgumentException if the function type is not supported
174- */
175- private Operator createInferenceOperator (InferenceFunction <?> f , DriverContext driverContext ) {
176- Operator .OperatorFactory factory = switch (f ) {
177- case TextEmbedding textEmbedding -> new TextEmbeddingOperator .Factory (
178- inferenceService ,
179- inferenceId (f ),
180- expressionEvaluatorFactory (textEmbedding .inputText ())
181- );
182- default -> throw new IllegalArgumentException ("Unknown inference function: " + f .getClass ().getName ());
183- };
184-
185- return factory .get (driverContext );
186- }
187-
188- /**
189- * Extracts the inference endpoint ID from an inference function.
190- *
191- * @param f the inference function containing the inference ID
192- * @return the inference endpoint ID as a string
193- */
194- private String inferenceId (InferenceFunction <?> f ) {
195- return BytesRefs .toString (f .inferenceId ().fold (foldContext ));
196- }
197-
198- /**
199- * Creates an expression evaluator factory for a foldable expression.
200- * <p>
201- * This method converts a foldable expression into an evaluator factory that can be used by inference
202- * operators. The expressionis first folded to its constant value and then wrapped in a literal.
203- *
204- * @param e the foldable expression to create an evaluator factory for
205- * @return an expression evaluator factory for the given expression
206- * @throws AssertionError if the expression is not foldable (in debug builds)
207- */
208- private ExpressionEvaluator .Factory expressionEvaluatorFactory (Expression e ) {
209- assert e .foldable () : "Input expression must be foldable" ;
210- return EvalMapper .toEvaluator (foldContext , Literal .of (foldContext , e ), null );
139+ }));
140+
141+ // Feed the operator with a single page to trigger execution
142+ // The actual input data is already bound in the operator through expression evaluators
143+ inferenceOperator .addInput (new Page (1 ));
144+ } catch (Exception e ) {
145+ listener .onFailure (e );
146+ } finally {
147+ driverContext .finish ();
148+ }
211149 }
212150
213151 /**
@@ -217,7 +155,6 @@ private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) {
217155 * allowing for easier testing and potential future extensibility. The provider is responsible
218156 * for creating an appropriate operator instance given an inference function and driver context.
219157 */
220- @ FunctionalInterface
221158 interface InferenceOperatorProvider {
222159 /**
223160 * Creates an inference operator for the given function and driver context.
@@ -228,4 +165,64 @@ interface InferenceOperatorProvider {
228165 */
229166 Operator getOperator (InferenceFunction <?> f , DriverContext driverContext );
230167 }
168+
169+ /**
170+ * Factory for creating {@link InferenceFunctionEvaluator} instances.
171+ */
172+ public static class Factory {
173+ private Factory () {}
174+
175+ /**
176+ * Creates a new inference function evaluator.
177+ *
178+ * @param foldContext the fold context
179+ * @param inferenceService the inference service
180+ * @return a new instance of {@link InferenceFunctionEvaluator}
181+ */
182+ public InferenceFunctionEvaluator create (FoldContext foldContext , InferenceService inferenceService ) {
183+ return new InferenceFunctionEvaluator (foldContext , createInferenceOperatorProvider (foldContext , inferenceService ));
184+ }
185+
186+ /**
187+ * Creates an {@link InferenceOperatorProvider} that can produce operators for all supported inference functions.
188+ */
189+ private InferenceOperatorProvider createInferenceOperatorProvider (FoldContext foldContext , InferenceService inferenceService ) {
190+ return (inferenceFunction , driverContext ) -> {
191+ Operator .OperatorFactory factory = switch (inferenceFunction ) {
192+ case TextEmbedding textEmbedding -> new TextEmbeddingOperator .Factory (
193+ inferenceService ,
194+ inferenceId (inferenceFunction , foldContext ),
195+ expressionEvaluatorFactory (textEmbedding .inputText (), foldContext )
196+ );
197+ default -> throw new IllegalArgumentException ("Unknown inference function: " + inferenceFunction .getClass ().getName ());
198+ };
199+
200+ return factory .get (driverContext );
201+ };
202+ }
203+
204+ /**
205+ * Extracts the inference endpoint ID from an inference function.
206+ *
207+ * @param f the inference function containing the inference ID
208+ * @return the inference endpoint ID as a string
209+ */
210+ private String inferenceId (InferenceFunction <?> f , FoldContext foldContext ) {
211+ return BytesRefs .toString (f .inferenceId ().fold (foldContext ));
212+ }
213+
214+ /**
215+ * Creates an expression evaluator factory for a foldable expression.
216+ * <p>
217+ * This method converts a foldable expression into an evaluator factory that can be used by inference
218+ * operators. The expression is first folded to its constant value and then wrapped in a literal.
219+ *
220+ * @param e the foldable expression to create an evaluator factory for
221+ * @return an expression evaluator factory for the given expression
222+ */
223+ private EvalOperator .ExpressionEvaluator .Factory expressionEvaluatorFactory (Expression e , FoldContext foldContext ) {
224+ assert e .foldable () : "Input expression must be foldable" ;
225+ return EvalMapper .toEvaluator (foldContext , Literal .of (foldContext , e ), null );
226+ }
227+ }
231228}
0 commit comments