Skip to content

Commit 3e9d3d5

Browse files
committed
Unit tests for InferenceFunctionEvaluator
1 parent 5ee2382 commit 3e9d3d5

File tree

2 files changed

+281
-11
lines changed

2 files changed

+281
-11
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,24 @@ public CircuitBreakerStats stats(String name) {
132132
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
133133
Page output = inferenceOperator.getOutput();
134134

135-
if (output == null) {
136-
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
137-
return;
138-
}
135+
try {
136+
if (output == null) {
137+
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
138+
return;
139+
}
139140

140-
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
141-
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
142-
return;
143-
}
141+
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
142+
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
143+
return;
144+
}
144145

145-
// Convert the operator result back to an ESQL expression (Literal)
146-
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
146+
// Convert the operator result back to an ESQL expression (Literal)
147+
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
148+
} finally {
149+
if (output != null) {
150+
output.releaseBlocks();
151+
}
152+
}
147153
}));
148154

149155
// Feed the operator with a single page to trigger execution
@@ -193,7 +199,7 @@ private String inferenceId(InferenceFunction<?> f) {
193199
* Creates an expression evaluator factory for a foldable expression.
194200
* <p>
195201
* This method converts a foldable expression into an evaluator factory that can be used by inference
196-
* operators. The expression is first folded to its constant value and then wrapped in a literal.
202+
* operators. The expressionis first folded to its constant value and then wrapped in a literal.
197203
*
198204
* @param e the foldable expression to create an evaluator factory for
199205
* @return an expression evaluator factory for the given expression
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.compute.data.FloatBlock;
12+
import org.elasticsearch.compute.data.Page;
13+
import org.elasticsearch.compute.operator.Operator;
14+
import org.elasticsearch.compute.test.ComputeTestCase;
15+
import org.elasticsearch.test.ESTestCase;
16+
import org.elasticsearch.threadpool.ThreadPool;
17+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
18+
import org.elasticsearch.xpack.esql.core.expression.Expression;
19+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
20+
import org.elasticsearch.xpack.esql.core.expression.Literal;
21+
import org.elasticsearch.xpack.esql.core.expression.function.Function;
22+
import org.elasticsearch.xpack.esql.core.tree.Source;
23+
import org.elasticsearch.xpack.esql.core.type.DataType;
24+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
25+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
26+
import org.junit.After;
27+
import org.junit.Before;
28+
29+
import java.util.List;
30+
import java.util.concurrent.atomic.AtomicReference;
31+
32+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
33+
import static org.hamcrest.Matchers.containsString;
34+
import static org.hamcrest.Matchers.equalTo;
35+
import static org.hamcrest.Matchers.instanceOf;
36+
import static org.mockito.Mockito.doAnswer;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.when;
39+
40+
public class InferenceFunctionEvaluatorTests extends ComputeTestCase {
41+
42+
private ThreadPool threadPool;
43+
44+
@Before
45+
public void setupThreadPool() {
46+
this.threadPool = createThreadPool();
47+
}
48+
49+
@After
50+
public void tearDownThreadPool() {
51+
terminate(threadPool);
52+
}
53+
54+
public void testFoldTextEmbeddingFunction() throws Exception {
55+
// Create a mock TextEmbedding function
56+
TextEmbedding textEmbeddingFunction = new TextEmbedding(
57+
Source.EMPTY,
58+
Literal.keyword(Source.EMPTY, "test-model"),
59+
Literal.keyword(Source.EMPTY, "test input")
60+
);
61+
62+
// Create a mock operator that returns a result
63+
Operator operator = mock(Operator.class);
64+
65+
Float[] embedding = randomArray(1, 100, Float[]::new, ESTestCase::randomFloat);
66+
67+
when(operator.getOutput()).thenAnswer(i -> {
68+
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(1).beginPositionEntry();
69+
70+
for (int j = 0; j < embedding.length; j++) {
71+
outputBlockBuilder.appendFloat(embedding[j]);
72+
}
73+
74+
outputBlockBuilder.endPositionEntry();
75+
76+
return new Page(outputBlockBuilder.build());
77+
});
78+
79+
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
80+
81+
// Execute the fold operation
82+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
83+
FoldContext.small(),
84+
mock(InferenceService.class),
85+
inferenceOperatorProvider
86+
);
87+
88+
AtomicReference<Expression> resultExpression = new AtomicReference<>();
89+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(resultExpression::set, ESTestCase::fail));
90+
91+
assertBusy(() -> {
92+
assertNotNull(resultExpression.get());
93+
Literal result = as(resultExpression.get(), Literal.class);
94+
assertThat(result.dataType(), equalTo(DataType.DENSE_VECTOR));
95+
assertThat(as(result.value(), List.class).toArray(), equalTo(embedding));
96+
});
97+
98+
// Check all breakers are empty after the operation is executed
99+
allBreakersEmpty();
100+
}
101+
102+
public void testFoldWithNonFoldableFunction() {
103+
// A function with a non-literal argument is not foldable.
104+
TextEmbedding textEmbeddingFunction = new TextEmbedding(
105+
Source.EMPTY,
106+
mock(Attribute.class),
107+
Literal.keyword(Source.EMPTY, "test input")
108+
);
109+
110+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
111+
FoldContext.small(),
112+
mock(InferenceService.class),
113+
(f, driverContext) -> mock(Operator.class)
114+
);
115+
116+
AtomicReference<Exception> error = new AtomicReference<>();
117+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
118+
119+
assertNotNull(error.get());
120+
assertThat(error.get(), instanceOf(IllegalArgumentException.class));
121+
assertThat(error.get().getMessage(), equalTo("Inference function must be foldable"));
122+
}
123+
124+
public void testFoldWithAsyncFailure() throws Exception {
125+
TextEmbedding textEmbeddingFunction = new TextEmbedding(
126+
Source.EMPTY,
127+
Literal.keyword(Source.EMPTY, "test-model"),
128+
Literal.keyword(Source.EMPTY, "test input")
129+
);
130+
131+
// Mock an operator that will trigger an async failure
132+
Operator operator = mock(Operator.class);
133+
doAnswer(invocation -> {
134+
// Simulate the operator finishing and then immediately calling the failure listener
135+
// This happens inside the `DriverContext` logic that the evaluator uses.
136+
// We can't directly access the listener, so we'll have the operator throw an exception
137+
// which should be caught and propagated to the listener.
138+
throw new RuntimeException("async failure");
139+
}).when(operator).addInput(new Page(1));
140+
141+
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
142+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
143+
FoldContext.small(),
144+
mock(InferenceService.class),
145+
inferenceOperatorProvider
146+
);
147+
148+
AtomicReference<Exception> error = new AtomicReference<>();
149+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
150+
151+
assertBusy(() -> assertNotNull(error.get()));
152+
assertThat(error.get(), instanceOf(RuntimeException.class));
153+
assertThat(error.get().getMessage(), equalTo("async failure"));
154+
155+
allBreakersEmpty();
156+
}
157+
158+
public void testFoldWithNullOutputPage() throws Exception {
159+
TextEmbedding textEmbeddingFunction = new TextEmbedding(
160+
Source.EMPTY,
161+
Literal.keyword(Source.EMPTY, "test-model"),
162+
Literal.keyword(Source.EMPTY, "test input")
163+
);
164+
165+
Operator operator = mock(Operator.class);
166+
when(operator.getOutput()).thenReturn(null);
167+
168+
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
169+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
170+
FoldContext.small(),
171+
mock(InferenceService.class),
172+
inferenceOperatorProvider
173+
);
174+
175+
AtomicReference<Exception> error = new AtomicReference<>();
176+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
177+
178+
assertBusy(() -> assertNotNull(error.get()));
179+
assertThat(error.get(), instanceOf(IllegalStateException.class));
180+
assertThat(error.get().getMessage(), equalTo("Expected output page from inference operator"));
181+
182+
allBreakersEmpty();
183+
}
184+
185+
public void testFoldWithMultiPositionOutputPage() throws Exception {
186+
TextEmbedding textEmbeddingFunction = new TextEmbedding(
187+
Source.EMPTY,
188+
Literal.keyword(Source.EMPTY, "test-model"),
189+
Literal.keyword(Source.EMPTY, "test input")
190+
);
191+
192+
Operator operator = mock(Operator.class);
193+
// Output page should have exactly one position for constant folding
194+
when(operator.getOutput()).thenReturn(new Page(blockFactory().newFloatBlockBuilder(2).build()));
195+
196+
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
197+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
198+
FoldContext.small(),
199+
mock(InferenceService.class),
200+
inferenceOperatorProvider
201+
);
202+
203+
AtomicReference<Exception> error = new AtomicReference<>();
204+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
205+
206+
assertBusy(() -> assertNotNull(error.get()));
207+
assertThat(error.get(), instanceOf(IllegalStateException.class));
208+
assertThat(error.get().getMessage(), equalTo("Expected a single block with a single value from inference operator"));
209+
210+
allBreakersEmpty();
211+
}
212+
213+
public void testFoldWithMultiBlockOutputPage() throws Exception {
214+
TextEmbedding textEmbeddingFunction = new TextEmbedding(
215+
Source.EMPTY,
216+
Literal.keyword(Source.EMPTY, "test-model"),
217+
Literal.keyword(Source.EMPTY, "test input")
218+
);
219+
220+
Operator operator = mock(Operator.class);
221+
// Output page should have exactly one block for constant folding
222+
when(operator.getOutput()).thenReturn(
223+
new Page(blockFactory().newFloatBlockBuilder(1).build(), blockFactory().newFloatBlockBuilder(1).build())
224+
);
225+
226+
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
227+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
228+
FoldContext.small(),
229+
mock(InferenceService.class),
230+
inferenceOperatorProvider
231+
);
232+
233+
AtomicReference<Exception> error = new AtomicReference<>();
234+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
235+
236+
assertBusy(() -> assertNotNull(error.get()));
237+
assertThat(error.get(), instanceOf(IllegalStateException.class));
238+
assertThat(error.get().getMessage(), equalTo("Expected a single block with a single value from inference operator"));
239+
240+
allBreakersEmpty();
241+
}
242+
243+
public void testFoldWithUnsupportedFunction() throws Exception {
244+
Function unsupported = mock(Function.class);
245+
when(unsupported.foldable()).thenReturn(true);
246+
247+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
248+
FoldContext.small(),
249+
mock(InferenceService.class),
250+
(f, driverContext) -> {
251+
throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName());
252+
}
253+
);
254+
255+
AtomicReference<Exception> error = new AtomicReference<>();
256+
evaluator.fold((InferenceFunction<?>) unsupported, ActionListener.wrap(r -> fail("should have failed"), error::set));
257+
258+
assertNotNull(error.get());
259+
assertThat(error.get(), instanceOf(IllegalArgumentException.class));
260+
assertThat(error.get().getMessage(), containsString("Unknown inference function"));
261+
262+
allBreakersEmpty();
263+
}
264+
}

0 commit comments

Comments
 (0)