Skip to content

Commit b0dfd1c

Browse files
committed
One more refactoring.
1 parent d8bd24e commit b0dfd1c

File tree

8 files changed

+98
-156
lines changed

8 files changed

+98
-156
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
import org.elasticsearch.compute.data.Page;
1313
import org.elasticsearch.compute.operator.AsyncOperator;
1414
import org.elasticsearch.compute.operator.DriverContext;
15+
import org.elasticsearch.core.Releasable;
1516
import org.elasticsearch.inference.InferenceServiceResults;
1617
import org.elasticsearch.threadpool.ThreadPool;
18+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1719
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
1820
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
19-
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
2021
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2122

22-
public abstract class InferenceOperator<InferenceResult extends InferenceServiceResults> extends AsyncOperator<Page> {
23+
import java.util.List;
24+
25+
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
26+
27+
public abstract class InferenceOperator<IR extends InferenceServiceResults> extends AsyncOperator<InferenceOperator.OnGoingInference<IR>> {
2328

2429
// Move to a setting.
2530
private static final int MAX_INFERENCE_WORKER = 10;
@@ -45,37 +50,74 @@ protected String inferenceId() {
4550
}
4651

4752
@Override
48-
protected void releaseFetchedOnAnyThread(Page page) {
49-
releasePageOnAnyThread(page);
53+
protected void releaseFetchedOnAnyThread(OnGoingInference<IR> onGoingInference) {
54+
releasePageOnAnyThread(onGoingInference.inputPage);
5055
}
5156

5257
@Override
5358
public Page getOutput() {
54-
return fetchFromBuffer();
59+
OnGoingInference<IR> onGoingInference = fetchFromBuffer();
60+
61+
if (onGoingInference == null) {
62+
return null;
63+
}
64+
65+
try (OutputBuilder<IR> outputBuilder = outputBuilder(onGoingInference)) {
66+
onGoingInference.inferenceResponses.forEach(outputBuilder::addInferenceResult);
67+
return outputBuilder.buildOutput();
68+
}
5569
}
5670

5771
@Override
58-
protected void performAsync(Page input, ActionListener<Page> listener) {
59-
try (OutputBuilder<InferenceResult> outputBuilder = outputBuilder(input); BulkInferenceRequestIterator requests = requests(input)) {
60-
bulkInferenceExecutor.execute(requests, outputBuilder, listener);
72+
protected void performAsync(Page input, ActionListener<OnGoingInference<IR>> listener) {
73+
try (BulkInferenceRequestIterator requests = requests(input)) {
74+
bulkInferenceExecutor.execute(requests, listener.map(r -> onGoingInference(input, r)));
6175
} catch (Exception e) {
6276
listener.onFailure(e);
6377
}
6478
}
6579

80+
private OnGoingInference<IR> onGoingInference(Page input, List<InferenceAction.Response> inferenceResponses) {
81+
return new OnGoingInference<>(input, inferenceResponses.stream().map(this::inferenceResults).toList());
82+
}
83+
84+
IR inferenceResults(InferenceAction.Response inferenceResponse) {
85+
InferenceServiceResults results = inferenceResponse.getResults();
86+
if (inferenceResultsClass().isInstance(results)) {
87+
return inferenceResultsClass().cast(results);
88+
}
89+
90+
throw new IllegalStateException(
91+
format(
92+
"Inference result has wrong type. Got [{}] while expecting [{}]",
93+
results.getClass().getName(),
94+
inferenceResultsClass().getName()
95+
)
96+
);
97+
}
98+
99+
protected abstract Class<IR> inferenceResultsClass();
100+
66101
protected BulkInferenceExecutionConfig bulkExecutionConfig() {
67102
return BulkInferenceExecutionConfig.DEFAULT;
68103
}
69104

70105
protected abstract BulkInferenceRequestIterator requests(Page input);
71106

72-
protected abstract OutputBuilder<InferenceResult> outputBuilder(Page input);
107+
protected abstract OutputBuilder<IR> outputBuilder(OnGoingInference<IR> onGoingInference);
108+
109+
public abstract static class OutputBuilder<IR extends InferenceServiceResults> implements Releasable {
110+
111+
public abstract void addInferenceResult(IR inferenceResult);
112+
113+
public abstract Page buildOutput();
73114

74-
public abstract static class OutputBuilder<InferenceResult extends InferenceServiceResults> extends BulkInferenceOutputBuilder<
75-
InferenceResult,
76-
Page> {
77115
protected void releasePageOnAnyThread(Page page) {
78116
InferenceOperator.releasePageOnAnyThread(page);
79117
}
80118
}
119+
120+
public record OnGoingInference<IR extends InferenceServiceResults>(Page inputPage, List<IR> inferenceResponses) {
121+
122+
}
81123
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
12-
import org.elasticsearch.core.CheckedConsumer;
13-
import org.elasticsearch.inference.InferenceServiceResults;
1412
import org.elasticsearch.threadpool.ThreadPool;
1513
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1614
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1715
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
1816

17+
import java.util.ArrayList;
18+
import java.util.Collections;
19+
import java.util.List;
1920
import java.util.concurrent.Executor;
2021
import java.util.concurrent.ExecutorService;
2122
import java.util.concurrent.TimeoutException;
23+
import java.util.function.Consumer;
2224

2325
public class BulkInferenceExecutor {
2426
private static final String TASK_RUNNER_NAME = "bulk_inference_operation";
@@ -28,28 +30,25 @@ public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadP
2830
throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, threadPool, bulkExecutionConfig);
2931
}
3032

31-
public <InferenceResult extends InferenceServiceResults, OutputType> void execute(
32-
BulkInferenceRequestIterator requests,
33-
BulkInferenceOutputBuilder<InferenceResult, OutputType> outputBuilder,
34-
ActionListener<OutputType> listener
35-
) {
33+
public void execute(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) {
3634
if (requests.hasNext() == false) {
37-
listener.onResponse(outputBuilder.buildOutput());
35+
listener.onResponse(List.of());
3836
return;
3937
}
4038

39+
final List<InferenceAction.Response> responses = new ArrayList<>();
4140
final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState();
4241

4342
try {
4443
enqueueRequests(requests, bulkExecutionState);
45-
persistsInferenceResponses(bulkExecutionState, outputBuilder::onInferenceResponse);
44+
persistsInferenceResponses(bulkExecutionState, responses::add);
4645
} catch (Exception e) {
4746
listener.onFailure(e);
4847
}
4948

5049
if (bulkExecutionState.hasFailure() == false) {
5150
try {
52-
listener.onResponse(outputBuilder.buildOutput());
51+
listener.onResponse(Collections.unmodifiableList(responses));
5352
return;
5453
} catch (Exception e) {
5554
listener.onFailure(e);
@@ -72,10 +71,8 @@ private void enqueueRequests(BulkInferenceRequestIterator requests, BulkInferenc
7271
}
7372
}
7473

75-
private void persistsInferenceResponses(
76-
BulkInferenceExecutionState bulkExecutionState,
77-
CheckedConsumer<InferenceAction.Response, Exception> persister
78-
) throws TimeoutException {
74+
private void persistsInferenceResponses(BulkInferenceExecutionState bulkExecutionState, Consumer<InferenceAction.Response> persister)
75+
throws TimeoutException {
7976
// TODO: retry should be from config
8077
int retry = 30;
8178
while (bulkExecutionState.getPersistedCheckpoint() < bulkExecutionState.getMaxSeqNo()) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceOutputBuilder.java

Lines changed: 0 additions & 37 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,25 @@ public String toString() {
6464
return "CompletionOperator[inference_id=[" + inferenceId() + "]]";
6565
}
6666

67+
@Override
68+
protected Class<ChatCompletionResults> inferenceResultsClass() {
69+
return ChatCompletionResults.class;
70+
}
71+
6772
@Override
6873
protected BulkInferenceRequestIterator requests(Page inputPage) {
6974
return new CompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId());
7075
}
7176

7277
@Override
73-
protected CompletionOperatorOutputBuilder outputBuilder(Page inputPage) {
78+
protected CompletionOperatorOutputBuilder outputBuilder(OnGoingInference<ChatCompletionResults> onGoingInference) {
7479
try {
75-
return new CompletionOperatorOutputBuilder(blockFactory().newBytesRefBlockBuilder(inputPage.getPositionCount()), inputPage);
80+
return new CompletionOperatorOutputBuilder(
81+
blockFactory().newBytesRefBlockBuilder(onGoingInference.inputPage().getPositionCount()),
82+
onGoingInference.inputPage()
83+
);
7684
} catch (Exception e) {
77-
releasePageOnAnyThread(inputPage);
85+
releaseFetchedOnAnyThread(onGoingInference);
7886
throw (e);
7987
}
8088
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void close() {
3535
}
3636

3737
@Override
38-
public void onInferenceResults(ChatCompletionResults completionResults) {
38+
public void addInferenceResult(ChatCompletionResults completionResults) {
3939
if (completionResults == null || completionResults.getResults().isEmpty()) {
4040
outputBlockBuilder.appendNull();
4141
} else {
@@ -49,11 +49,6 @@ public void onInferenceResults(ChatCompletionResults completionResults) {
4949
}
5050
}
5151

52-
@Override
53-
protected Class<ChatCompletionResults> inferenceResultsClass() {
54-
return ChatCompletionResults.class;
55-
}
56-
5752
@Override
5853
public Page buildOutput() {
5954
if (isOutputBuilt.compareAndSet(false, true)) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,21 +79,26 @@ public String toString() {
7979
return "RerankOperator[inference_id=[" + inferenceId() + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
8080
}
8181

82+
@Override
83+
protected Class<RankedDocsResults> inferenceResultsClass() {
84+
return RankedDocsResults.class;
85+
}
86+
8287
@Override
8388
protected RerankOperatorRequestIterator requests(Page inputPage) {
8489
return new RerankOperatorRequestIterator((BytesRefBlock) rowEncoder.eval(inputPage), inferenceId(), queryText, batchSize);
8590
}
8691

8792
@Override
88-
protected RerankOperatorOutputBuilder outputBuilder(Page inputPage) {
93+
protected RerankOperatorOutputBuilder outputBuilder(OnGoingInference<RankedDocsResults> onGoingInference) {
8994
try {
9095
return new RerankOperatorOutputBuilder(
91-
blockFactory().newDoubleBlockBuilder(inputPage.getPositionCount()),
92-
inputPage,
96+
blockFactory().newDoubleBlockBuilder(onGoingInference.inputPage().getPositionCount()),
97+
onGoingInference.inputPage(),
9398
scoreChannel
9499
);
95100
} catch (Exception e) {
96-
releasePageOnAnyThread(inputPage);
101+
releaseFetchedOnAnyThread(onGoingInference);
97102
throw (e);
98103
}
99104
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@ public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page i
2828
this.scoreChannel = scoreChannel;
2929
}
3030

31-
@Override
32-
protected Class<RankedDocsResults> inferenceResultsClass() {
33-
return RankedDocsResults.class;
34-
}
35-
3631
@Override
3732
public void close() {
3833
releasePageOnAnyThread(inputPage);
@@ -61,7 +56,7 @@ public Page buildOutput() {
6156
}
6257

6358
@Override
64-
public void onInferenceResults(RankedDocsResults results) {
59+
public void addInferenceResult(RankedDocsResults results) {
6560
results.getRankedDocs()
6661
.stream()
6762
.sorted(Comparator.comparingInt(RankedDocsResults.RankedDoc::index))

0 commit comments

Comments
 (0)