Skip to content

Commit 65759c5

Browse files
authored
[ES|QL] Fix broken CSV tests for text embedding (#136300)
* Fix a bug in CSV tests when several embedding functions are present in the same plan. * Fix GenerativeForkIT text-embedding tests. * Lint.
1 parent 7439976 commit 65759c5

File tree

5 files changed

+70
-21
lines changed

5 files changed

+70
-21
lines changed

muted-tests.yml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -591,9 +591,6 @@ tests:
591591
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
592592
method: test {csv-spec:math.PowIntInt}
593593
issue: https://github.com/elastic/elasticsearch/issues/136106
594-
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
595-
method: test {csv-spec:text-embedding.Text_embedding with knn on semantic_text_dense_field}
596-
issue: https://github.com/elastic/elasticsearch/issues/136108
597594
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
598595
method: test {csv-spec:string.ContainsFail}
599596
issue: https://github.com/elastic/elasticsearch/issues/136112
@@ -606,9 +603,6 @@ tests:
606603
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
607604
method: test {csv-spec:bucket.BucketByWeekInString}
608605
issue: https://github.com/elastic/elasticsearch/issues/136136
609-
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
610-
method: test {csv-spec:text-embedding.Text_embedding with knn (inline) on semantic_text_dense_field}
611-
issue: https://github.com/elastic/elasticsearch/issues/136142
612606
- class: org.elasticsearch.versioning.ConcurrentSeqNoVersioningIT
613607
method: testSeqNoCASLinearizability
614608
issue: https://github.com/elastic/elasticsearch/issues/117249
@@ -624,15 +618,9 @@ tests:
624618
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
625619
method: test {csv-spec:math.LeastGreatestMany}
626620
issue: https://github.com/elastic/elasticsearch/issues/136161
627-
- class: org.elasticsearch.xpack.esql.qa.*.EsqlSpecIT
628-
method: test {csv-spec:text-embedding.*}
629-
issue: https://github.com/elastic/elasticsearch/issues/136090
630621
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
631622
method: test {csv-spec:categorize.Values aggs}
632623
issue: https://github.com/elastic/elasticsearch/issues/136230
633-
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
634-
method: test {csv-spec:text-embedding.Text_embedding using a row source operator}
635-
issue: https://github.com/elastic/elasticsearch/issues/136234
636624
- class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT
637625
method: test {p0=field_caps/10_basic/Field caps for number field with only doc values}
638626
issue: https://github.com/elastic/elasticsearch/issues/136244

x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,25 @@ live long and prosper
6969
all we have to decide is what to do with the time that is given to us
7070
;
7171

72+
text_embedding with multiple knn queries
73+
required_capability: text_embedding_function
74+
required_capability: dense_vector_field_type_released
75+
required_capability: knn_function_v5
76+
required_capability: fork_v9
77+
required_capability: semantic_text_field_caps
78+
79+
FROM semantic_text METADATA _score, _id
80+
| WHERE KNN(semantic_text_dense_field, TEXT_EMBEDDING("be excellent to each other", "test_dense_inference")) OR KNN(semantic_text_dense_field, TEXT_EMBEDDING("live long and prosper", "test_dense_inference"))
81+
| SORT _score DESC, _id
82+
| LIMIT 10
83+
| KEEP semantic_text_field
84+
;
85+
86+
semantic_text_field:text
87+
live long and prosper
88+
be excellent to each other
89+
all we have to decide is what to do with the time that is given to us
90+
;
7291

7392
text_embedding with multiple knn queries in fork
7493
required_capability: text_embedding_function
@@ -91,5 +110,5 @@ live long and prosper | [50.0, 5
91110
live long and prosper | [45.0, 55.0, 54.0] | fork1
92111
be excellent to each other | [50.0, 57.0, 56.0] | fork2
93112
all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | fork1
94-
all we have to decide is what to do with the time that is given to us | [50.0, 57.0, 56.0] | fork2
113+
all we have to decide is what to do with the time that is given to us | [50.0, 57.0, 56.0] | fork2
95114
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.common.breaker.CircuitBreaker;
12+
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
1213
import org.elasticsearch.common.lucene.BytesRefs;
1314
import org.elasticsearch.common.util.BigArrays;
1415
import org.elasticsearch.compute.data.BlockFactory;
@@ -88,7 +89,7 @@ public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) {
8889
// Set up a DriverContext for executing the inference operator.
8990
// This follows the same pattern as EvaluatorMapper but in a simplified context
9091
// suitable for constant folding during optimization.
91-
CircuitBreaker breaker = foldContext.circuitBreakerView(f.source());
92+
CircuitBreaker breaker = new NoopCircuitBreaker(CircuitBreaker.REQUEST);
9293
BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() {
9394
@Override
9495
public CircuitBreaker getBreaker(String name) {
@@ -123,26 +124,28 @@ public CircuitBreakerStats stats(String name) {
123124
// Execute the inference operation asynchronously and handle the result
124125
// The operator will perform the actual inference call and return a page with the result
125126
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
126-
Page output = inferenceOperator.getOutput();
127-
128127
try {
128+
Page output = inferenceOperator.getOutput();
129+
129130
if (output == null) {
130131
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
131132
return;
132133
}
133134

135+
output.allowPassingToDifferentDriver();
136+
l = ActionListener.releaseBefore(output, l);
137+
134138
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
135139
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
136140
return;
137141
}
138142

139143
// Convert the operator result back to an ESQL expression (Literal)
140144
l.onResponse(Literal.of(f, processValue(f.dataType(), BlockUtils.toJavaObject(output.getBlock(0), 0))));
145+
} catch (Exception e) {
146+
l.onFailure(e);
141147
} finally {
142148
Releasables.close(inferenceOperator);
143-
if (output != null) {
144-
output.releaseBlocks();
145-
}
146149
}
147150
}));
148151
} catch (Exception e) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ private static String inferenceId(UnresolvedFunction f, FunctionDefinition def)
188188

189189
for (int i = 0; i < functionDescription.args().size(); i++) {
190190
EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i);
191+
if (i >= f.arguments().size()) {
192+
// Argument is missing. The verifier will report the failure later, so just return null here.
193+
return null;
194+
}
191195

192196
if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) {
193197
Expression inferenceId = f.arguments().get(i);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
import org.elasticsearch.compute.data.Page;
1313
import org.elasticsearch.compute.operator.Operator;
1414
import org.elasticsearch.compute.test.ComputeTestCase;
15+
import org.elasticsearch.core.TimeValue;
1516
import org.elasticsearch.test.ESTestCase;
1617
import org.elasticsearch.threadpool.ThreadPool;
18+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
19+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
1720
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1821
import org.elasticsearch.xpack.esql.core.expression.Expression;
1922
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -22,6 +25,7 @@
2225
import org.elasticsearch.xpack.esql.core.type.DataType;
2326
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
2427
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
28+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
2529
import org.junit.After;
2630
import org.junit.Before;
2731

@@ -33,6 +37,7 @@
3337
import static org.hamcrest.Matchers.equalTo;
3438
import static org.hamcrest.Matchers.instanceOf;
3539
import static org.hamcrest.Matchers.nullValue;
40+
import static org.mockito.ArgumentMatchers.any;
3641
import static org.mockito.Mockito.doAnswer;
3742
import static org.mockito.Mockito.mock;
3843
import static org.mockito.Mockito.when;
@@ -51,6 +56,7 @@ public void tearDownThreadPool() {
5156
terminate(threadPool);
5257
}
5358

59+
@SuppressWarnings("unchecked")
5460
public void testFoldTextEmbeddingFunction() throws Exception {
5561
// Create a mock TextEmbedding function
5662
TextEmbedding textEmbeddingFunction = new TextEmbedding(
@@ -62,7 +68,23 @@ public void testFoldTextEmbeddingFunction() throws Exception {
6268
// Create a mock operator that returns a result
6369
Operator operator = mock(Operator.class);
6470

65-
Float[] embedding = randomArray(1, 100, Float[]::new, ESTestCase::randomFloat);
71+
float[] embedding = randomEmbedding(between(1, 100));
72+
73+
InferenceService inferenceService = mock(InferenceService.class);
74+
BulkInferenceRunner bulkInferenceRunner = mock(BulkInferenceRunner.class);
75+
76+
doAnswer(i -> {
77+
threadPool.schedule(
78+
() -> i.getArgument(1, ActionListener.class).onResponse(List.of(inferenceResponse(embedding))),
79+
TimeValue.timeValueMillis(between(1, 10)),
80+
threadPool.generic()
81+
);
82+
83+
return null;
84+
}).when(bulkInferenceRunner).executeBulk(any(), any());
85+
when(bulkInferenceRunner.threadPool()).thenReturn(threadPool);
86+
87+
when(inferenceService.bulkInferenceRunner()).thenReturn(bulkInferenceRunner);
6688

6789
when(operator.getOutput()).thenAnswer(i -> {
6890
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(1).beginPositionEntry();
@@ -79,7 +101,7 @@ public void testFoldTextEmbeddingFunction() throws Exception {
79101
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
80102

81103
// Execute the fold operation
82-
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider);
104+
InferenceFunctionEvaluator evaluator = InferenceFunctionEvaluator.factory().create(FoldContext.small(), inferenceService);
83105

84106
AtomicReference<Expression> resultExpression = new AtomicReference<>();
85107
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(resultExpression::set, ESTestCase::fail));
@@ -218,4 +240,17 @@ public void testFoldWithUnsupportedFunction() throws Exception {
218240

219241
allBreakersEmpty();
220242
}
243+
244+
private float[] randomEmbedding(int length) {
245+
float[] embedding = new float[length];
246+
for (int i = 0; i < length; i++) {
247+
embedding[i] = randomFloat();
248+
}
249+
return embedding;
250+
}
251+
252+
private InferenceAction.Response inferenceResponse(float[] embedding) {
253+
TextEmbeddingFloatResults.Embedding embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding);
254+
return new InferenceAction.Response(new TextEmbeddingFloatResults(List.of(embeddingResult)));
255+
}
221256
}

0 commit comments

Comments
 (0)