Skip to content

Commit 5604055

Browse files
committed
Unit tests for InferenceFunctionEvaluator
1 parent 3e9d3d5 commit 5604055

File tree

3 files changed

+114
-194
lines changed

3 files changed

+114
-194
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 101 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import org.elasticsearch.compute.data.BlockUtils;
1616
import org.elasticsearch.compute.data.Page;
1717
import org.elasticsearch.compute.operator.DriverContext;
18-
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
18+
import org.elasticsearch.compute.operator.EvalOperator;
1919
import org.elasticsearch.compute.operator.Operator;
2020
import org.elasticsearch.indices.breaker.AllCircuitBreakerStats;
2121
import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -42,37 +42,24 @@
4242
*/
4343
public class InferenceFunctionEvaluator {
4444

45-
private final FoldContext foldContext;
46-
private final InferenceService inferenceService;
47-
private final InferenceOperatorProvider inferenceOperatorProvider;
45+
private static final Factory FACTORY = new Factory();
4846

49-
/**
50-
* Creates a new inference function evaluator with the default operator provider.
51-
*
52-
* @param foldContext the fold context containing circuit breakers and evaluation settings
53-
* @param inferenceService the inference service for executing inference operations
54-
*/
55-
public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) {
56-
this.foldContext = foldContext;
57-
this.inferenceService = inferenceService;
58-
this.inferenceOperatorProvider = this::createInferenceOperator;
47+
public static InferenceFunctionEvaluator.Factory factory() {
48+
return FACTORY;
5949
}
6050

51+
private final FoldContext foldContext;
52+
private final InferenceOperatorProvider inferenceOperatorProvider;
53+
6154
/**
6255
* Creates a new inference function evaluator with a custom operator provider.
6356
* This constructor is primarily used for testing to inject mock operator providers.
6457
*
6558
* @param foldContext the fold context containing circuit breakers and evaluation settings
66-
* @param inferenceService the inference service for executing inference operations
6759
* @param inferenceOperatorProvider custom provider for creating inference operators
6860
*/
69-
InferenceFunctionEvaluator(
70-
FoldContext foldContext,
71-
InferenceService inferenceService,
72-
InferenceOperatorProvider inferenceOperatorProvider
73-
) {
61+
InferenceFunctionEvaluator(FoldContext foldContext, InferenceOperatorProvider inferenceOperatorProvider) {
7462
this.foldContext = foldContext;
75-
this.inferenceService = inferenceService;
7663
this.inferenceOperatorProvider = inferenceOperatorProvider;
7764
}
7865

@@ -90,7 +77,6 @@ public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService infe
9077
*
9178
* @param f the inference function to fold - must be foldable (have constant parameters)
9279
* @param listener the listener to notify when folding completes successfully or fails
93-
* @throws IllegalArgumentException if the function is not foldable
9480
*/
9581
public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) {
9682
if (f.foldable() == false) {
@@ -125,89 +111,41 @@ public CircuitBreakerStats stats(String name) {
125111
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
126112

127113
// Create the inference operator for the specific function type using the provider
128-
Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext);
129-
130-
// Execute the inference operation asynchronously and handle the result
131-
// The operator will perform the actual inference call and return a page with the result
132-
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
133-
Page output = inferenceOperator.getOutput();
134-
135-
try {
136-
if (output == null) {
137-
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
138-
return;
139-
}
140114

141-
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
142-
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
143-
return;
115+
try (Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext)) {
116+
// Execute the inference operation asynchronously and handle the result
117+
// The operator will perform the actual inference call and return a page with the result
118+
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
119+
Page output = inferenceOperator.getOutput();
120+
121+
try {
122+
if (output == null) {
123+
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
124+
return;
125+
}
126+
127+
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
128+
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
129+
return;
130+
}
131+
132+
// Convert the operator result back to an ESQL expression (Literal)
133+
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
134+
} finally {
135+
if (output != null) {
136+
output.releaseBlocks();
137+
}
144138
}
145-
146-
// Convert the operator result back to an ESQL expression (Literal)
147-
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
148-
} finally {
149-
if (output != null) {
150-
output.releaseBlocks();
151-
}
152-
}
153-
}));
154-
155-
// Feed the operator with a single page to trigger execution
156-
// The actual input data is already bound in the operator through expression evaluators
157-
inferenceOperator.addInput(new Page(1));
158-
159-
driverContext.finish();
160-
}
161-
162-
/**
163-
* Creates an inference operator for the given function type and driver context.
164-
* <p>
165-
* This method uses pattern matching to determine the correct operator factory based on
166-
* the inference function type, creates the factory, and then instantiates the operator
167-
* with the provided driver context. Each supported inference function type has its own
168-
* specialized operator implementation.
169-
*
170-
* @param f the inference function to create an operator for
171-
* @param driverContext the driver context to use for operator creation
172-
* @return an operator instance configured for the given function type
173-
* @throws IllegalArgumentException if the function type is not supported
174-
*/
175-
private Operator createInferenceOperator(InferenceFunction<?> f, DriverContext driverContext) {
176-
Operator.OperatorFactory factory = switch (f) {
177-
case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory(
178-
inferenceService,
179-
inferenceId(f),
180-
expressionEvaluatorFactory(textEmbedding.inputText())
181-
);
182-
default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName());
183-
};
184-
185-
return factory.get(driverContext);
186-
}
187-
188-
/**
189-
* Extracts the inference endpoint ID from an inference function.
190-
*
191-
* @param f the inference function containing the inference ID
192-
* @return the inference endpoint ID as a string
193-
*/
194-
private String inferenceId(InferenceFunction<?> f) {
195-
return BytesRefs.toString(f.inferenceId().fold(foldContext));
196-
}
197-
198-
/**
199-
* Creates an expression evaluator factory for a foldable expression.
200-
* <p>
201-
* This method converts a foldable expression into an evaluator factory that can be used by inference
202-
* operators. The expressionis first folded to its constant value and then wrapped in a literal.
203-
*
204-
* @param e the foldable expression to create an evaluator factory for
205-
* @return an expression evaluator factory for the given expression
206-
* @throws AssertionError if the expression is not foldable (in debug builds)
207-
*/
208-
private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) {
209-
assert e.foldable() : "Input expression must be foldable";
210-
return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null);
139+
}));
140+
141+
// Feed the operator with a single page to trigger execution
142+
// The actual input data is already bound in the operator through expression evaluators
143+
inferenceOperator.addInput(new Page(1));
144+
} catch (Exception e) {
145+
listener.onFailure(e);
146+
} finally {
147+
driverContext.finish();
148+
}
211149
}
212150

213151
/**
@@ -217,7 +155,6 @@ private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) {
217155
* allowing for easier testing and potential future extensibility. The provider is responsible
218156
* for creating an appropriate operator instance given an inference function and driver context.
219157
*/
220-
@FunctionalInterface
221158
interface InferenceOperatorProvider {
222159
/**
223160
* Creates an inference operator for the given function and driver context.
@@ -228,4 +165,64 @@ interface InferenceOperatorProvider {
228165
*/
229166
Operator getOperator(InferenceFunction<?> f, DriverContext driverContext);
230167
}
168+
169+
/**
170+
* Factory for creating {@link InferenceFunctionEvaluator} instances.
171+
*/
172+
public static class Factory {
173+
private Factory() {}
174+
175+
/**
176+
* Creates a new inference function evaluator.
177+
*
178+
* @param foldContext the fold context
179+
* @param inferenceService the inference service
180+
* @return a new instance of {@link InferenceFunctionEvaluator}
181+
*/
182+
public InferenceFunctionEvaluator create(FoldContext foldContext, InferenceService inferenceService) {
183+
return new InferenceFunctionEvaluator(foldContext, createInferenceOperatorProvider(foldContext, inferenceService));
184+
}
185+
186+
/**
187+
* Creates an {@link InferenceOperatorProvider} that can produce operators for all supported inference functions.
188+
*/
189+
private InferenceOperatorProvider createInferenceOperatorProvider(FoldContext foldContext, InferenceService inferenceService) {
190+
return (inferenceFunction, driverContext) -> {
191+
Operator.OperatorFactory factory = switch (inferenceFunction) {
192+
case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory(
193+
inferenceService,
194+
inferenceId(inferenceFunction, foldContext),
195+
expressionEvaluatorFactory(textEmbedding.inputText(), foldContext)
196+
);
197+
default -> throw new IllegalArgumentException("Unknown inference function: " + inferenceFunction.getClass().getName());
198+
};
199+
200+
return factory.get(driverContext);
201+
};
202+
}
203+
204+
/**
205+
* Extracts the inference endpoint ID from an inference function.
206+
*
207+
* @param f the inference function containing the inference ID
208+
* @return the inference endpoint ID as a string
209+
*/
210+
private String inferenceId(InferenceFunction<?> f, FoldContext foldContext) {
211+
return BytesRefs.toString(f.inferenceId().fold(foldContext));
212+
}
213+
214+
/**
215+
* Creates an expression evaluator factory for a foldable expression.
216+
* <p>
217+
* This method converts a foldable expression into an evaluator factory that can be used by inference
218+
* operators. The expression is first folded to its constant value and then wrapped in a literal.
219+
*
220+
* @param e the foldable expression to create an evaluator factory for
221+
* @return an expression evaluator factory for the given expression
222+
*/
223+
private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e, FoldContext foldContext) {
224+
assert e.foldable() : "Input expression must be foldable";
225+
return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null);
226+
}
227+
}
231228
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ public class FoldInferenceFunctions implements PreOptimizerRule {
3737
private final InferenceFunctionEvaluator inferenceFunctionEvaluator;
3838

3939
public FoldInferenceFunctions(LogicalPreOptimizerContext preOptimizerContext) {
40-
inferenceFunctionEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService());
40+
inferenceFunctionEvaluator = InferenceFunctionEvaluator.factory()
41+
.create(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService());
4142
}
4243

4344
@Override

0 commit comments

Comments
 (0)