Skip to content

Commit 48e69be

Browse files
authored
Fix an error with InferenceFunctionEvaluator when outputing a unidime… (#136210)
1 parent c8de545 commit 48e69be

File tree

4 files changed

+60
-12
lines changed

4 files changed

+60
-12
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,6 @@ tests:
615615
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeForkIT
616616
method: test {csv-spec:semantic_text.Repeat}
617617
issue: https://github.com/elastic/elasticsearch/issues/136150
618-
- class: org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluatorTests
619-
method: testFoldTextEmbeddingFunction
620-
issue: https://github.com/elastic/elasticsearch/issues/136154
621618
- class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS1UnavailableRemotesIT
622619
method: testEsqlRcs1UnavailableRemoteScenarios
623620
issue: https://github.com/elastic/elasticsearch/issues/136157

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public boolean foldable() {
109109

110110
@Override
111111
public DataType dataType() {
112-
return DataType.DENSE_VECTOR;
112+
return inputText.dataType() == DataType.NULL ? DataType.NULL : DataType.DENSE_VECTOR;
113113
}
114114

115115
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
import org.elasticsearch.xpack.esql.core.expression.Expression;
2525
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
2626
import org.elasticsearch.xpack.esql.core.expression.Literal;
27+
import org.elasticsearch.xpack.esql.core.type.DataType;
2728
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
2829
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
2930
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
3031
import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator;
3132

33+
import java.util.List;
34+
3235
/**
3336
* Evaluator for inference functions that performs constant folding by executing inference operations
3437
* at optimization time and replacing them with their computed results.
@@ -76,6 +79,11 @@ public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) {
7679
listener.onFailure(new IllegalArgumentException("Inference function must be foldable"));
7780
return;
7881
}
82+
if (f.dataType() == DataType.NULL) {
83+
// If the function's return type is NULL, we can directly return a NULL literal without executing anything.
84+
listener.onResponse(Literal.of(f, null));
85+
return;
86+
}
7987

8088
// Set up a DriverContext for executing the inference operator.
8189
// This follows the same pattern as EvaluatorMapper but in a simplified context
@@ -129,7 +137,7 @@ public CircuitBreakerStats stats(String name) {
129137
}
130138

131139
// Convert the operator result back to an ESQL expression (Literal)
132-
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
140+
l.onResponse(Literal.of(f, processValue(f.dataType(), BlockUtils.toJavaObject(output.getBlock(0), 0))));
133141
} finally {
134142
Releasables.close(inferenceOperator);
135143
if (output != null) {
@@ -148,6 +156,14 @@ public CircuitBreakerStats stats(String name) {
148156
}
149157
}
150158

159+
private Object processValue(DataType dataType, Object value) {
160+
if (dataType == DataType.DENSE_VECTOR && value instanceof List == false) {
161+
value = List.of(value);
162+
}
163+
164+
return value;
165+
}
166+
151167
/**
152168
* Functional interface for providing inference operators.
153169
* <p>

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import static org.hamcrest.Matchers.containsString;
3333
import static org.hamcrest.Matchers.equalTo;
3434
import static org.hamcrest.Matchers.instanceOf;
35+
import static org.hamcrest.Matchers.nullValue;
3536
import static org.mockito.Mockito.doAnswer;
3637
import static org.mockito.Mockito.mock;
3738
import static org.mockito.Mockito.when;
@@ -54,8 +55,8 @@ public void testFoldTextEmbeddingFunction() throws Exception {
5455
// Create a mock TextEmbedding function
5556
TextEmbedding textEmbeddingFunction = new TextEmbedding(
5657
Source.EMPTY,
57-
Literal.keyword(Source.EMPTY, "test-model"),
58-
Literal.keyword(Source.EMPTY, "test input")
58+
Literal.keyword(Source.EMPTY, "test input"),
59+
Literal.keyword(Source.EMPTY, "test-model")
5960
);
6061

6162
// Create a mock operator that returns a result
@@ -94,12 +95,46 @@ public void testFoldTextEmbeddingFunction() throws Exception {
9495
allBreakersEmpty();
9596
}
9697

98+
public void testFoldTextEmbeddingFunctionWithNullInput() throws Exception {
99+
// Create a mock TextEmbedding function
100+
TextEmbedding textEmbeddingFunction = new TextEmbedding(Source.EMPTY, Literal.NULL, Literal.keyword(Source.EMPTY, "test-model"));
101+
102+
// Create a mock operator that returns a result
103+
Operator operator = mock(Operator.class);
104+
105+
Float[] embedding = randomArray(1, 100, Float[]::new, ESTestCase::randomFloat);
106+
107+
when(operator.getOutput()).thenAnswer(i -> {
108+
FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(1);
109+
outputBlockBuilder.appendNull();
110+
return new Page(outputBlockBuilder.build());
111+
});
112+
113+
InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
114+
115+
// Execute the fold operation
116+
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider);
117+
118+
AtomicReference<Expression> resultExpression = new AtomicReference<>();
119+
evaluator.fold(textEmbeddingFunction, ActionListener.wrap(resultExpression::set, ESTestCase::fail));
120+
121+
assertBusy(() -> {
122+
assertNotNull(resultExpression.get());
123+
Literal result = as(resultExpression.get(), Literal.class);
124+
assertThat(result.dataType(), equalTo(DataType.NULL));
125+
assertThat(result.value(), nullValue());
126+
});
127+
128+
// Check all breakers are empty after the operation is executed
129+
allBreakersEmpty();
130+
}
131+
97132
public void testFoldWithNonFoldableFunction() {
98133
// A function with a non-literal argument is not foldable.
99134
TextEmbedding textEmbeddingFunction = new TextEmbedding(
100135
Source.EMPTY,
101136
mock(Attribute.class),
102-
Literal.keyword(Source.EMPTY, "test input")
137+
Literal.keyword(Source.EMPTY, "test model")
103138
);
104139

105140
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
@@ -118,8 +153,8 @@ public void testFoldWithNonFoldableFunction() {
118153
public void testFoldWithAsyncFailure() throws Exception {
119154
TextEmbedding textEmbeddingFunction = new TextEmbedding(
120155
Source.EMPTY,
121-
Literal.keyword(Source.EMPTY, "test-model"),
122-
Literal.keyword(Source.EMPTY, "test input")
156+
Literal.keyword(Source.EMPTY, "test input"),
157+
Literal.keyword(Source.EMPTY, "test-model")
123158
);
124159

125160
// Mock an operator that will trigger an async failure
@@ -146,8 +181,8 @@ public void testFoldWithAsyncFailure() throws Exception {
146181
public void testFoldWithNullOutputPage() throws Exception {
147182
TextEmbedding textEmbeddingFunction = new TextEmbedding(
148183
Source.EMPTY,
149-
Literal.keyword(Source.EMPTY, "test-model"),
150-
Literal.keyword(Source.EMPTY, "test input")
184+
Literal.keyword(Source.EMPTY, "test input"),
185+
Literal.keyword(Source.EMPTY, "test-model")
151186
);
152187

153188
Operator operator = mock(Operator.class);

0 commit comments

Comments
 (0)