Skip to content

Commit d09bce7

Browse files
committed
Bulk inference refactoring.
1 parent c8321e5 commit d09bce7

File tree

12 files changed

+378
-217
lines changed

12 files changed

+378
-217
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/BulkInferenceOperation.java

Lines changed: 0 additions & 124 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionOperator.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@
1010
import org.apache.lucene.util.BytesRef;
1111
import org.apache.lucene.util.BytesRefBuilder;
1212
import org.elasticsearch.compute.data.Block;
13-
import org.elasticsearch.compute.data.BlockFactory;
1413
import org.elasticsearch.compute.data.BytesRefBlock;
1514
import org.elasticsearch.compute.data.Page;
1615
import org.elasticsearch.compute.operator.DriverContext;
1716
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
1817
import org.elasticsearch.compute.operator.Operator;
1918
import org.elasticsearch.core.Releasables;
2019
import org.elasticsearch.inference.TaskType;
20+
import org.elasticsearch.threadpool.ThreadPool;
2121
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2222
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
23+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
24+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2325

2426
import java.util.List;
2527
import java.util.NoSuchElementException;
@@ -36,22 +38,27 @@ public String describe() {
3638

3739
@Override
3840
public Operator get(DriverContext driverContext) {
39-
return new CompletionOperator(driverContext, inferenceRunner, inferenceId, promptEvaluatorFactory.get(driverContext));
41+
return new CompletionOperator(
42+
driverContext,
43+
inferenceRunner,
44+
inferenceRunner.threadPool(),
45+
inferenceId,
46+
promptEvaluatorFactory.get(driverContext)
47+
);
4048
}
4149
}
4250

4351
private final ExpressionEvaluator promptEvaluator;
44-
private final BlockFactory blockFactory;
4552

4653
public CompletionOperator(
4754
DriverContext driverContext,
4855
InferenceRunner inferenceRunner,
56+
ThreadPool threadPool,
4957
String inferenceId,
5058
ExpressionEvaluator promptEvaluator
5159
) {
52-
super(driverContext, inferenceRunner, inferenceId);
60+
super(driverContext, inferenceRunner, threadPool, inferenceId);
5361
this.promptEvaluator = promptEvaluator;
54-
this.blockFactory = driverContext.blockFactory();
5562
}
5663

5764
@Override
@@ -65,8 +72,8 @@ public String toString() {
6572
}
6673

6774
@Override
68-
protected RequestIterator requests(Page inputPage) {
69-
return new InferenceOperator.RequestIterator() {
75+
protected BulkInferenceRequestIterator requests(Page inputPage) {
76+
return new BulkInferenceRequestIterator() {
7077
private final BytesRefBlock promptBlock = (BytesRefBlock) promptEvaluator.eval(inputPage);
7178
private BytesRef readBuffer = new BytesRef();
7279
private int currentPos = 0;
@@ -105,9 +112,9 @@ public void close() {
105112
}
106113

107114
@Override
108-
protected OutputBuilder<ChatCompletionResults> outputBuilder(Page inputPage) {
109-
return new InferenceOperator.OutputBuilder<>() {
110-
private final BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(inputPage.getPositionCount());
115+
protected BulkInferenceOutputBuilder<ChatCompletionResults, Page> outputBuilder(Page inputPage) {
116+
return new BulkInferenceOutputBuilder<>() {
117+
private final BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(inputPage.getPositionCount());
111118
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
112119

113120
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceExecutionContext.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,28 @@
77

88
package org.elasticsearch.xpack.esql.inference;
99

10+
import org.elasticsearch.core.TimeValue;
11+
1012
import java.util.concurrent.ExecutorService;
1113

1214
public class InferenceExecutionContext {
1315
private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10;
16+
private static final TimeValue DEFAULT_INFERENCE_EXECUTION_TIMEOUT = TimeValue.timeValueSeconds(10);
1417
private final InferenceRunner inferenceRunner;
1518
private final ExecutorService executorService;
1619
private final int maxConcurrentRequests;
20+
private final TimeValue inferenceExecutionTimeout;
1721

18-
private InferenceExecutionContext(InferenceRunner inferenceRunner, ExecutorService executorService, int maxConcurrentRequests) {
22+
private InferenceExecutionContext(
23+
InferenceRunner inferenceRunner,
24+
ExecutorService executorService,
25+
int maxConcurrentRequests,
26+
TimeValue inferenceExecutionTimeout
27+
) {
1928
this.inferenceRunner = inferenceRunner;
2029
this.executorService = executorService;
2130
this.maxConcurrentRequests = maxConcurrentRequests;
31+
this.inferenceExecutionTimeout = inferenceExecutionTimeout;
2232
}
2333

2434
public InferenceRunner inferenceRunner() {
@@ -33,23 +43,33 @@ public int maxConcurrentRequests() {
3343
return maxConcurrentRequests;
3444
}
3545

46+
/**
 * Returns the inference execution timeout held by this context.
 * The value is the one supplied through the {@code Builder}, which starts out at
 * {@code DEFAULT_INFERENCE_EXECUTION_TIMEOUT} unless explicitly overridden.
 *
 * @return the configured inference execution timeout
 */
public TimeValue inferenceExecutionTimeout() {
47+
return inferenceExecutionTimeout;
48+
}
49+
3650
public static class Builder {
3751
private final InferenceRunner inferenceRunner;
3852
private final ExecutorService executorService;
3953
private int maxConcurrentRequests = DEFAULT_MAX_CONCURRENT_REQUESTS;
54+
private TimeValue inferenceExecutionTimeout = DEFAULT_INFERENCE_EXECUTION_TIMEOUT;
4055

4156
/**
 * Creates a builder bound to the given runner and executor service.
 * The remaining options keep their default values until a setter is called.
 *
 * @param inferenceRunner the inference runner handed to the built context
 * @param executorService the executor service handed to the built context
 */
Builder(InferenceRunner inferenceRunner, ExecutorService executorService) {
4257
this.inferenceRunner = inferenceRunner;
4358
this.executorService = executorService;
4459
}
4560

4661
public InferenceExecutionContext build() {
47-
return new InferenceExecutionContext(inferenceRunner, executorService, maxConcurrentRequests);
62+
return new InferenceExecutionContext(inferenceRunner, executorService, maxConcurrentRequests, inferenceExecutionTimeout);
4863
}
4964

5065
/**
 * Overrides the maximum number of concurrent requests stored in the built context.
 *
 * @param maxConcurrentRequests the value to store; not validated here
 * @return this builder, to allow call chaining
 */
public Builder setMaxConcurrentRequests(int maxConcurrentRequests) {
5166
this.maxConcurrentRequests = maxConcurrentRequests;
5267
return this;
5368
}
69+
70+
/**
 * Overrides the inference execution timeout stored in the built context.
 *
 * @param inferenceExecutionTimeout the value to store; not validated here (a {@code null} is stored as-is)
 * @return this builder, to allow call chaining
 */
public Builder setInferenceExecutionTimeout(TimeValue inferenceExecutionTimeout) {
71+
this.inferenceExecutionTimeout = inferenceExecutionTimeout;
72+
return this;
73+
}
5474
}
5575
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 23 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,38 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.compute.data.BlockFactory;
1112
import org.elasticsearch.compute.data.Page;
1213
import org.elasticsearch.compute.operator.AsyncOperator;
1314
import org.elasticsearch.compute.operator.DriverContext;
14-
import org.elasticsearch.core.CheckedConsumer;
15-
import org.elasticsearch.core.Releasable;
16-
import org.elasticsearch.core.Releasables;
1715
import org.elasticsearch.inference.InferenceServiceResults;
18-
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
19-
20-
import java.util.Iterator;
21-
22-
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
16+
import org.elasticsearch.threadpool.ThreadPool;
17+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
18+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
19+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
20+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2321

2422
public abstract class InferenceOperator<InferenceResult extends InferenceServiceResults> extends AsyncOperator<Page> {
2523

2624
// TODO: move this limit to a configurable setting instead of a hard-coded constant.
2725
private static final int MAX_INFERENCE_WORKER = 10;
28-
private final InferenceRunner inferenceRunner;
2926
private final String inferenceId;
27+
private final BlockFactory blockFactory;
3028

31-
public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, String inferenceId) {
32-
super(driverContext, inferenceRunner.threadContext(), MAX_INFERENCE_WORKER);
33-
this.inferenceRunner = inferenceRunner;
29+
private final BulkInferenceExecutor<InferenceResult, Page> bulkInferenceExecutor;
30+
31+
/**
 * Creates the operator and the {@code BulkInferenceExecutor} it delegates to.
 *
 * @param driverContext   driver context; also the source of the block factory kept by this operator
 * @param inferenceRunner runner passed on to the bulk inference executor
 * @param threadPool      thread pool passed to the executor; its thread context is handed to the parent operator
 * @param inferenceId     inference id exposed through {@code inferenceId()}
 */
// NOTE(review): the suppression is presumably needed because the constructor calls the
// overridable bulkExecutionConfig() below — confirm.
@SuppressWarnings("this-escape")
32+
public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, ThreadPool threadPool, String inferenceId) {
33+
super(driverContext, threadPool.getThreadContext(), MAX_INFERENCE_WORKER);
34+
this.blockFactory = driverContext.blockFactory();
35+
// bulkExecutionConfig() is invoked here, before any subclass constructor body has run.
this.bulkInferenceExecutor = new BulkInferenceExecutor<>(inferenceRunner, threadPool, bulkExecutionConfig());
3436
this.inferenceId = inferenceId;
3537
}
3638

39+
/**
 * Returns the block factory that was taken from the driver context when this operator was constructed.
 *
 * @return the block factory kept by this operator
 */
protected BlockFactory blockFactory() {
40+
return blockFactory;
41+
}
42+
3743
/**
 * Returns the inference id this operator was constructed with.
 *
 * @return the inference id
 */
protected String inferenceId() {
3844
return inferenceId;
3945
}
@@ -50,52 +56,14 @@ public Page getOutput() {
5056

5157
@Override
5258
protected void performAsync(Page input, ActionListener<Page> listener) {
53-
final RequestIterator requests = requests(input);
54-
final OutputBuilder<InferenceResult> outputBuilder = outputBuilder(input);
55-
56-
new BulkInferenceOperation(requests, outputBuilder).execute(
57-
inferenceExecutionContext(),
58-
listener.delegateFailureIgnoreResponseAndWrap(l -> {
59-
l.onResponse(outputBuilder.buildOutput());
60-
Releasables.closeExpectNoException(requests, outputBuilder);
61-
})
62-
);
59+
bulkInferenceExecutor.execute(requests(input), outputBuilder(input), listener);
6360
}
6461

65-
protected InferenceExecutionContext inferenceExecutionContext() {
66-
return inferenceRunner.executionContextBuilder().build();
62+
/**
 * Returns the configuration used to create this operator's {@code BulkInferenceExecutor}.
 * The base implementation returns {@code BulkInferenceExecutionConfig.DEFAULT}.
 * <p>
 * This method is called from the {@code InferenceOperator} constructor, so an override runs before the
 * subclass constructor body and must not depend on subclass fields assigned there.
 *
 * @return the bulk execution configuration
 */
protected BulkInferenceExecutionConfig bulkExecutionConfig() {
63+
return BulkInferenceExecutionConfig.DEFAULT;
64+
}
6865

69-
protected abstract RequestIterator requests(Page input);
70-
71-
protected abstract OutputBuilder<InferenceResult> outputBuilder(Page input);
72-
73-
public abstract static class OutputBuilder<InferenceResults extends InferenceServiceResults>
74-
implements
75-
CheckedConsumer<InferenceAction.Response, Exception>,
76-
Releasable {
77-
protected abstract Class<InferenceResults> inferenceResultsClass();
78-
79-
public abstract Page buildOutput();
80-
81-
public abstract void onInferenceResults(InferenceResults results);
82-
83-
@Override
84-
public void accept(InferenceAction.Response response) throws Exception {
85-
InferenceServiceResults results = response.getResults();
86-
if (inferenceResultsClass().isInstance(response.getResults()) == false) {
87-
throw new IllegalStateException(
88-
format(
89-
"Inference result has wrong type. Got [{}] while expecting [{}]",
90-
results.getClass().getName(),
91-
inferenceResultsClass().getName()
92-
)
93-
);
94-
}
95-
96-
onInferenceResults(inferenceResultsClass().cast(results));
97-
}
98-
}
66+
/**
 * Creates the request iterator for an input page; {@code performAsync} passes the result to the
 * bulk inference executor.
 *
 * @param input the page being processed
 * @return the request iterator for {@code input}
 */
protected abstract BulkInferenceRequestIterator requests(Page input);
9967

100-
public interface RequestIterator extends Iterator<InferenceAction.Request>, Releasable {}
68+
/**
 * Creates the output builder for an input page; {@code performAsync} passes the result to the
 * bulk inference executor together with the request iterator.
 *
 * @param input the page being processed
 * @return the output builder for {@code input}
 */
protected abstract BulkInferenceOutputBuilder<InferenceResult, Page> outputBuilder(Page input);
10169
}

0 commit comments

Comments
 (0)