Skip to content

Commit ea3de8b

Browse files
committed
More CSV tests.
1 parent 816c410 commit ea3de8b

File tree

3 files changed

+91
-67
lines changed

3 files changed

+91
-67
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,16 @@ ROW input="Who is Victor Hugo?"
99
input:keyword | embedding:dense_vector
1010
Who is Victor Hugo? | [56.0, 50.0, 48.0]
1111
;
12+
13+
14+
text_embedding using a ROW source operator with query build using CONCAT
15+
required_capability: text_embedding_function
16+
required_capability: dense_vector_field_type
17+
18+
ROW input="Who is Victor Hugo?"
19+
| EVAL embedding = TEXT_EMBEDDING(CONCAT("Who is ", "Victor Hugo?"), "test_dense_inference")
20+
;
21+
22+
input:keyword | embedding:dense_vector
23+
Who is Victor Hugo? | [56.0, 50.0, 48.0]
24+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.compute.operator.DriverContext;
1818
import org.elasticsearch.compute.operator.EvalOperator;
1919
import org.elasticsearch.compute.operator.Operator;
20+
import org.elasticsearch.core.Releasables;
2021
import org.elasticsearch.indices.breaker.AllCircuitBreakerStats;
2122
import org.elasticsearch.indices.breaker.CircuitBreakerService;
2223
import org.elasticsearch.indices.breaker.CircuitBreakerStats;
@@ -111,36 +112,43 @@ public CircuitBreakerStats stats(String name) {
111112
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
112113

113114
// Create the inference operator for the specific function type using the provider
114-
115-
try (Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext)) {
116-
// Execute the inference operation asynchronously and handle the result
117-
// The operator will perform the actual inference call and return a page with the result
118-
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
119-
Page output = inferenceOperator.getOutput();
120-
121-
try {
122-
if (output == null) {
123-
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
124-
return;
115+
try {
116+
Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext);
117+
118+
try {
119+
// Feed the operator with a single page to trigger execution
120+
// The actual input data is already bound in the operator through expression evaluators
121+
inferenceOperator.addInput(new Page(1));
122+
123+
// Execute the inference operation asynchronously and handle the result
124+
// The operator will perform the actual inference call and return a page with the result
125+
driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
126+
Page output = inferenceOperator.getOutput();
127+
128+
try {
129+
if (output == null) {
130+
l.onFailure(new IllegalStateException("Expected output page from inference operator"));
131+
return;
132+
}
133+
134+
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
135+
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
136+
return;
137+
}
138+
139+
// Convert the operator result back to an ESQL expression (Literal)
140+
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
141+
} finally {
142+
Releasables.close(inferenceOperator);
143+
if (output != null) {
144+
output.releaseBlocks();
145+
}
125146
}
126-
127-
if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
128-
l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
129-
return;
130-
}
131-
132-
// Convert the operator result back to an ESQL expression (Literal)
133-
l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
134-
} finally {
135-
if (output != null) {
136-
output.releaseBlocks();
137-
}
138-
}
139-
}));
140-
141-
// Feed the operator with a single page to trigger execution
142-
// The actual input data is already bound in the operator through expression evaluators
143-
inferenceOperator.addInput(new Page(1));
147+
}));
148+
} catch (Exception e) {
149+
Releasables.close(inferenceOperator);
150+
listener.onFailure(e);
151+
}
144152
} catch (Exception e) {
145153
listener.onFailure(e);
146154
} finally {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference.bulk;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.ThreadedActionListener;
1112
import org.elasticsearch.client.internal.Client;
1213
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
1314
import org.elasticsearch.threadpool.ThreadPool;
@@ -25,7 +26,6 @@
2526

2627
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
2728
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
28-
import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME;
2929

3030
/**
3131
* Implementation of bulk inference execution with throttling and concurrency control.
@@ -88,7 +88,7 @@ public BulkInferenceRequest poll() {
8888
public BulkInferenceRunner(Client client, int maxRunningTasks) {
8989
this.permits = new Semaphore(maxRunningTasks);
9090
this.client = client;
91-
this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME);
91+
this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
9292
}
9393

9494
/**
@@ -253,48 +253,51 @@ private void executePendingRequests(int recursionDepth) {
253253
executionState.finish();
254254
}
255255

256-
final ActionListener<InferenceAction.Response> inferenceResponseListener = ActionListener.runAfter(
257-
ActionListener.wrap(
258-
r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
259-
e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
260-
),
261-
() -> {
262-
// Release the permit we used
263-
permits.release();
264-
265-
try {
266-
synchronized (executionState) {
267-
persistPendingResponses();
268-
}
256+
final ActionListener<InferenceAction.Response> inferenceResponseListener = new ThreadedActionListener<>(
257+
executor,
258+
ActionListener.runAfter(
259+
ActionListener.wrap(
260+
r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
261+
e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
262+
),
263+
() -> {
264+
// Release the permit we used
265+
permits.release();
266+
267+
try {
268+
synchronized (executionState) {
269+
persistPendingResponses();
270+
}
269271

270-
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
271-
onBulkCompletion();
272-
}
272+
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
273+
onBulkCompletion();
274+
}
273275

274-
if (responseSent.get()) {
275-
// Response has already been sent
276-
// No need to continue processing this bulk.
277-
// Check if another bulk request is pending for execution.
278-
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
279-
if (nexBulkRequest != null) {
280-
executor.execute(nexBulkRequest::executePendingRequests);
276+
if (responseSent.get()) {
277+
// Response has already been sent
278+
// No need to continue processing this bulk.
279+
// Check if another bulk request is pending for execution.
280+
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
281+
if (nexBulkRequest != null) {
282+
executor.execute(nexBulkRequest::executePendingRequests);
283+
}
284+
return;
281285
}
282-
return;
283-
}
284-
if (executionState.finished() == false) {
285-
// Execute any pending requests if any
286-
if (recursionDepth > 100) {
287-
executor.execute(this::executePendingRequests);
288-
} else {
289-
this.executePendingRequests(recursionDepth + 1);
286+
if (executionState.finished() == false) {
287+
// Execute any pending requests if any
288+
if (recursionDepth > 100) {
289+
executor.execute(this::executePendingRequests);
290+
} else {
291+
this.executePendingRequests(recursionDepth + 1);
292+
}
293+
}
294+
} catch (Exception e) {
295+
if (responseSent.compareAndSet(false, true)) {
296+
completionListener.onFailure(e);
290297
}
291-
}
292-
} catch (Exception e) {
293-
if (responseSent.compareAndSet(false, true)) {
294-
completionListener.onFailure(e);
295298
}
296299
}
297-
}
300+
)
298301
);
299302

300303
// Handle null requests (edge case in some iterators)

0 commit comments

Comments
 (0)