|
8 | 8 | package org.elasticsearch.xpack.esql.inference; |
9 | 9 |
|
10 | 10 | import org.elasticsearch.action.ActionListener; |
11 | | -import org.elasticsearch.common.util.concurrent.ThreadContext; |
| 11 | +import org.elasticsearch.compute.data.Page; |
12 | 12 | import org.elasticsearch.compute.operator.AsyncOperator; |
13 | 13 | import org.elasticsearch.compute.operator.DriverContext; |
14 | 14 | import org.elasticsearch.inference.InferenceServiceResults; |
| 15 | +import org.elasticsearch.inference.TaskType; |
15 | 16 | import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
| 17 | +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOperation; |
| 18 | +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder; |
| 19 | +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; |
16 | 20 |
|
17 | | -import static org.elasticsearch.common.logging.LoggerMessageFormat.format; |
18 | | - |
19 | | -abstract public class InferenceOperator<Fetched, InferenceResult extends InferenceServiceResults> extends AsyncOperator<Fetched> { |
| 21 | +public abstract class InferenceOperator<InferenceResult extends InferenceServiceResults> extends AsyncOperator<Page> { |
20 | 22 |
|
21 | 23 | // Move to a setting. |
22 | 24 | private static final int MAX_INFERENCE_WORKER = 10; |
23 | 25 | private final InferenceRunner inferenceRunner; |
24 | 26 | private final String inferenceId; |
25 | | - private final Class<InferenceResult> inferenceResultClass; |
26 | | - |
27 | | - public InferenceOperator( |
28 | | - DriverContext driverContext, |
29 | | - ThreadContext threadContext, |
30 | | - InferenceRunner inferenceRunner, |
31 | | - String inferenceId, |
32 | | - Class<InferenceResult> inferenceResultClass |
33 | | - ) { |
34 | | - super(driverContext, threadContext, MAX_INFERENCE_WORKER); |
| 27 | + |
| 28 | + public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, String inferenceId) { |
| 29 | + super(driverContext, inferenceRunner.threadContext(), MAX_INFERENCE_WORKER); |
35 | 30 | this.inferenceRunner = inferenceRunner; |
36 | 31 | this.inferenceId = inferenceId; |
37 | | - this.inferenceResultClass = inferenceResultClass; |
| 32 | + } |
38 | 33 |
|
39 | | - assert inferenceRunner.getThreadContext() != null; |
| 34 | + protected String inferenceId() { |
| 35 | + return inferenceId; |
40 | 36 | } |
41 | 37 |
|
42 | | - protected final void doInference(InferenceAction.Request inferenceRequest, ActionListener<InferenceResult> listener) { |
43 | | - inferenceRunner.doInference(inferenceRequest, listener.map(this::checkedInferenceResults)); |
| 38 | + @Override |
| 39 | + protected void releaseFetchedOnAnyThread(Page page) { |
| 40 | + releasePageOnAnyThread(page); |
44 | 41 | } |
45 | 42 |
|
46 | | - protected String inferenceId() { |
47 | | - return inferenceId; |
| 43 | + @Override |
| 44 | + public Page getOutput() { |
| 45 | + return fetchFromBuffer(); |
| 46 | + } |
| 47 | + |
| 48 | + @Override |
| 49 | + protected void performAsync(Page input, ActionListener<Page> listener) { |
| 50 | + new BulkInferenceOperation<>(bulkInferenceRequestIterator(input), bulkOutputBuilder(input)).execute(inferenceRunner, listener); |
48 | 51 | } |
49 | 52 |
|
50 | | - private InferenceResult checkedInferenceResults(InferenceAction.Response inferenceResponse) { |
51 | | - if (inferenceResultClass.isInstance(inferenceResponse.getResults())) { |
52 | | - return inferenceResultClass.cast(inferenceResponse.getResults()); |
53 | | - } |
54 | | - throw new IllegalStateException( |
55 | | - format( |
56 | | - "Inference result has wrong type. Got [{}] while expecting [{}]", |
57 | | - inferenceResponse.getResults().getClass().getName(), |
58 | | - inferenceResultClass.getName() |
59 | | - ) |
60 | | - ); |
| 53 | + protected InferenceAction.Request.Builder inferenceRequestBuilder() { |
| 54 | + return InferenceAction.Request.builder(inferenceId, taskType()); |
61 | 55 | } |
| 56 | + |
| 57 | + protected abstract TaskType taskType(); |
| 58 | + |
| 59 | + protected abstract BulkInferenceRequestIterator bulkInferenceRequestIterator(Page input); |
| 60 | + |
| 61 | + protected abstract BulkInferenceOutputBuilder<InferenceResult, Page> bulkOutputBuilder(Page input); |
62 | 62 | } |
0 commit comments