Skip to content

Commit 30159c5

Browse files
committed
Code simplification.
1 parent 5ea7d24 commit 30159c5

File tree

10 files changed

+135
-153
lines changed

10 files changed

+135
-153
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 13 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,27 @@
1212
import org.elasticsearch.compute.data.Page;
1313
import org.elasticsearch.compute.operator.AsyncOperator;
1414
import org.elasticsearch.compute.operator.DriverContext;
15-
import org.elasticsearch.core.Releasable;
1615
import org.elasticsearch.inference.InferenceServiceResults;
1716
import org.elasticsearch.threadpool.ThreadPool;
18-
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1917
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
2018
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
19+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
2120
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2221

23-
import java.util.List;
24-
25-
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
26-
27-
public abstract class InferenceOperator<IR extends InferenceServiceResults> extends AsyncOperator<InferenceOperator.OnGoingInference<IR>> {
22+
public abstract class InferenceOperator<IR extends InferenceServiceResults> extends AsyncOperator<Page> {
2823

2924
// Move to a setting.
3025
private static final int MAX_INFERENCE_WORKER = 10;
3126
private final String inferenceId;
3227
private final BlockFactory blockFactory;
3328

34-
private final BulkInferenceExecutor bulkInferenceExecutor;
29+
private final BulkInferenceExecutor<IR, Page> bulkInferenceExecutor;
3530

3631
@SuppressWarnings("this-escape")
3732
public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, ThreadPool threadPool, String inferenceId) {
3833
super(driverContext, threadPool.getThreadContext(), MAX_INFERENCE_WORKER);
3934
this.blockFactory = driverContext.blockFactory();
40-
this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig());
35+
this.bulkInferenceExecutor = new BulkInferenceExecutor<IR, Page>(inferenceRunner, threadPool, bulkExecutionConfig());
4136
this.inferenceId = inferenceId;
4237
}
4338

@@ -50,74 +45,31 @@ protected String inferenceId() {
5045
}
5146

5247
@Override
53-
protected void releaseFetchedOnAnyThread(OnGoingInference<IR> onGoingInference) {
54-
releasePageOnAnyThread(onGoingInference.inputPage);
48+
protected void releaseFetchedOnAnyThread(Page fetched) {
49+
releasePageOnAnyThread(fetched);
5550
}
5651

5752
@Override
5853
public Page getOutput() {
59-
OnGoingInference<IR> onGoingInference = fetchFromBuffer();
60-
61-
if (onGoingInference == null) {
62-
return null;
63-
}
64-
65-
try (OutputBuilder<IR> outputBuilder = outputBuilder(onGoingInference)) {
66-
onGoingInference.inferenceResponses.forEach(outputBuilder::addInferenceResult);
67-
return outputBuilder.buildOutput();
68-
}
54+
return fetchFromBuffer();
6955
}
7056

7157
@Override
72-
protected void performAsync(Page input, ActionListener<OnGoingInference<IR>> listener) {
73-
try (BulkInferenceRequestIterator requests = requests(input)) {
74-
bulkInferenceExecutor.execute(requests, listener.map(r -> onGoingInference(input, r)));
58+
protected void performAsync(Page input, ActionListener<Page> listener) {
59+
try (BulkInferenceRequestIterator requests = requests(input); BulkInferenceOutputBuilder<IR, Page> outputBuilder = outputBuilder(input)) {
60+
bulkInferenceExecutor.execute(requests, outputBuilder, listener);
7561
} catch (Exception e) {
7662
listener.onFailure(e);
63+
} finally {
64+
releasePageOnAnyThread(input);
7765
}
7866
}
7967

80-
private OnGoingInference<IR> onGoingInference(Page input, List<InferenceAction.Response> inferenceResponses) {
81-
return new OnGoingInference<>(input, inferenceResponses.stream().map(this::inferenceResults).toList());
82-
}
83-
84-
IR inferenceResults(InferenceAction.Response inferenceResponse) {
85-
InferenceServiceResults results = inferenceResponse.getResults();
86-
if (inferenceResultsClass().isInstance(results)) {
87-
return inferenceResultsClass().cast(results);
88-
}
89-
90-
throw new IllegalStateException(
91-
format(
92-
"Inference result has wrong type. Got [{}] while expecting [{}]",
93-
results.getClass().getName(),
94-
inferenceResultsClass().getName()
95-
)
96-
);
97-
}
98-
99-
protected abstract Class<IR> inferenceResultsClass();
100-
10168
protected BulkInferenceExecutionConfig bulkExecutionConfig() {
10269
return BulkInferenceExecutionConfig.DEFAULT;
10370
}
10471

10572
protected abstract BulkInferenceRequestIterator requests(Page input);
10673

107-
protected abstract OutputBuilder<IR> outputBuilder(OnGoingInference<IR> onGoingInference);
108-
109-
public abstract static class OutputBuilder<IR extends InferenceServiceResults> implements Releasable {
110-
111-
public abstract void addInferenceResult(IR inferenceResult);
112-
113-
public abstract Page buildOutput();
114-
115-
protected void releasePageOnAnyThread(Page page) {
116-
InferenceOperator.releasePageOnAnyThread(page);
117-
}
118-
}
119-
120-
public record OnGoingInference<IR extends InferenceServiceResults>(Page inputPage, List<IR> inferenceResponses) {
121-
122-
}
74+
protected abstract BulkInferenceOutputBuilder<IR, Page> outputBuilder(Page input);
12375
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,54 +9,66 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
12+
import org.elasticsearch.inference.InferenceServiceResults;
1213
import org.elasticsearch.threadpool.ThreadPool;
1314
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1415
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1516
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
1617

17-
import java.util.ArrayList;
18-
import java.util.Collections;
19-
import java.util.List;
2018
import java.util.concurrent.Executor;
2119
import java.util.concurrent.ExecutorService;
2220
import java.util.concurrent.TimeoutException;
2321
import java.util.function.Consumer;
2422

25-
public class BulkInferenceExecutor {
23+
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
24+
25+
public class BulkInferenceExecutor <IR extends InferenceServiceResults, OutputType>{
2626
private static final String TASK_RUNNER_NAME = "bulk_inference_operation";
2727
private final ThrottledInferenceRunner throttledInferenceRunner;
2828

2929
public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) {
3030
throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, threadPool, bulkExecutionConfig);
3131
}
3232

33-
public void execute(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) {
34-
if (requests.hasNext() == false) {
35-
listener.onResponse(List.of());
36-
return;
33+
public void execute(BulkInferenceRequestIterator requests, BulkInferenceOutputBuilder<IR, OutputType> outputBuilder, ActionListener<OutputType> listener) {
34+
try {
35+
listener.onResponse(doExecute(requests, outputBuilder));
36+
} catch (Exception e) {
37+
listener.onFailure(e);
3738
}
39+
}
3840

39-
final List<InferenceAction.Response> responses = new ArrayList<>();
41+
public OutputType doExecute(BulkInferenceRequestIterator requests, BulkInferenceOutputBuilder<IR, OutputType> outputBuilder) throws Exception {
4042
final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState();
4143

42-
try {
44+
if (requests.hasNext()) {
4345
enqueueRequests(requests, bulkExecutionState);
44-
persistsInferenceResponses(bulkExecutionState, responses::add);
45-
} catch (Exception e) {
46-
listener.onFailure(e);
46+
persistsInferenceResponses(bulkExecutionState, this.inferenceResultPersister(outputBuilder));
4747
}
4848

49-
if (bulkExecutionState.hasFailure() == false) {
50-
try {
51-
listener.onResponse(Collections.unmodifiableList(responses));
52-
return;
53-
} catch (Exception e) {
54-
listener.onFailure(e);
55-
return;
56-
}
49+
if (bulkExecutionState.hasFailure()) {
50+
throw bulkExecutionState.getFailure();
5751
}
5852

59-
listener.onFailure(bulkExecutionState.getFailure());
53+
return outputBuilder.buildOutput();
54+
}
55+
56+
private Consumer<InferenceAction.Response> inferenceResultPersister(BulkInferenceOutputBuilder<IR, OutputType> outputBuilder) {
57+
return inferenceResponse -> {
58+
InferenceServiceResults results = inferenceResponse.getResults();
59+
if (outputBuilder.inferenceResultsClass().isInstance(results)) {
60+
outputBuilder.addInferenceResults(outputBuilder.inferenceResultsClass().cast(results));
61+
return;
62+
}
63+
64+
throw new IllegalStateException(
65+
format(
66+
"Inference result has wrong type. Got [{}] while expecting [{}]",
67+
results.getClass().getName(),
68+
outputBuilder.inferenceResultsClass().getName()
69+
)
70+
);
71+
};
6072
}
6173

6274
private void enqueueRequests(BulkInferenceRequestIterator requests, BulkInferenceExecutionState bulkExecutionState) {
@@ -79,10 +91,15 @@ private void persistsInferenceResponses(BulkInferenceExecutionState bulkExecutio
7991
Long seqNo = bulkExecutionState.fetchProcessedSeqNo();
8092
retry--;
8193

82-
if (seqNo == null && retry < 0) {
83-
throw new TimeoutException("timeout waiting for inference response");
94+
if (seqNo == null) {
95+
if (retry < 0) {
96+
throw new TimeoutException("timeout waiting for inference response");
97+
}
98+
break;
8499
}
85100

101+
retry = 30;
102+
86103
long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint();
87104

88105
while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.bulk;
9+
10+
import org.elasticsearch.core.Releasable;
11+
import org.elasticsearch.inference.InferenceServiceResults;
12+
13+
public interface BulkInferenceOutputBuilder<IR extends InferenceServiceResults, OutputType> extends Releasable {
14+
void addInferenceResults(IR inferenceResults);
15+
16+
Class<IR> inferenceResultsClass();
17+
18+
OutputType buildOutput();
19+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,13 @@ public String toString() {
6464
return "CompletionOperator[inference_id=[" + inferenceId() + "]]";
6565
}
6666

67-
@Override
68-
protected Class<ChatCompletionResults> inferenceResultsClass() {
69-
return ChatCompletionResults.class;
70-
}
71-
7267
@Override
7368
protected BulkInferenceRequestIterator requests(Page inputPage) {
7469
return new CompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId());
7570
}
7671

7772
@Override
78-
protected CompletionOperatorOutputBuilder outputBuilder(OnGoingInference<ChatCompletionResults> onGoingInference) {
79-
try {
80-
return new CompletionOperatorOutputBuilder(
81-
blockFactory().newBytesRefBlockBuilder(onGoingInference.inputPage().getPositionCount()),
82-
onGoingInference.inputPage()
83-
);
84-
} catch (Exception e) {
85-
releaseFetchedOnAnyThread(onGoingInference);
86-
throw (e);
87-
}
73+
protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
74+
return new CompletionOperatorOutputBuilder(blockFactory().newBytesRefBlockBuilder(input.getPositionCount()), input);
8875
}
89-
9076
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
import org.elasticsearch.core.Releasables;
1515
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
1616
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
17+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
1718

1819
import java.util.concurrent.atomic.AtomicBoolean;
1920

20-
public class CompletionOperatorOutputBuilder extends InferenceOperator.OutputBuilder<ChatCompletionResults> {
21+
public class CompletionOperatorOutputBuilder implements BulkInferenceOutputBuilder<ChatCompletionResults, Page> {
2122
private final Page inputPage;
2223
private final BytesRefBlock.Builder outputBlockBuilder;
2324
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
@@ -28,15 +29,19 @@ public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder,
2829
this.outputBlockBuilder = outputBlockBuilder;
2930
}
3031

32+
@Override
33+
public Class<ChatCompletionResults> inferenceResultsClass() {
34+
return ChatCompletionResults.class;
35+
}
36+
3137
@Override
3238
public void close() {
3339
Releasables.close(outputBlockBuilder);
34-
releasePageOnAnyThread(inputPage);
3540
}
3641

3742
@Override
38-
public void addInferenceResult(ChatCompletionResults completionResults) {
39-
if (completionResults == null || completionResults.getResults().isEmpty()) {
43+
public void addInferenceResults(ChatCompletionResults completionResults) {
44+
if (completionResults == null) {
4045
outputBlockBuilder.appendNull();
4146
} else {
4247
outputBlockBuilder.beginPositionEntry();
@@ -51,12 +56,8 @@ public void addInferenceResult(ChatCompletionResults completionResults) {
5156

5257
@Override
5358
public Page buildOutput() {
54-
if (isOutputBuilt.compareAndSet(false, true)) {
55-
Block outputBlock = outputBlockBuilder.build();
56-
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
57-
return inputPage.shallowCopy().appendBlock(outputBlock);
58-
}
59-
60-
throw new IllegalStateException("buildOutput has already been called");
59+
Block outputBlock = outputBlockBuilder.build();
60+
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
61+
return inputPage.shallowCopy().appendBlock(outputBlock);
6162
}
6263
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -79,27 +79,13 @@ public String toString() {
7979
return "RerankOperator[inference_id=[" + inferenceId() + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
8080
}
8181

82-
@Override
83-
protected Class<RankedDocsResults> inferenceResultsClass() {
84-
return RankedDocsResults.class;
85-
}
86-
8782
@Override
8883
protected RerankOperatorRequestIterator requests(Page inputPage) {
8984
return new RerankOperatorRequestIterator((BytesRefBlock) rowEncoder.eval(inputPage), inferenceId(), queryText, batchSize);
9085
}
9186

9287
@Override
93-
protected RerankOperatorOutputBuilder outputBuilder(OnGoingInference<RankedDocsResults> onGoingInference) {
94-
try {
95-
return new RerankOperatorOutputBuilder(
96-
blockFactory().newDoubleBlockBuilder(onGoingInference.inputPage().getPositionCount()),
97-
onGoingInference.inputPage(),
98-
scoreChannel
99-
);
100-
} catch (Exception e) {
101-
releaseFetchedOnAnyThread(onGoingInference);
102-
throw (e);
103-
}
88+
protected RerankOperatorOutputBuilder outputBuilder(Page input) {
89+
return new RerankOperatorOutputBuilder(blockFactory().newDoubleBlockBuilder(input.getPositionCount()), input, scoreChannel);
10490
}
10591
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import org.elasticsearch.compute.data.Page;
1313
import org.elasticsearch.core.Releasables;
1414
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
15-
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
15+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
1616

1717
import java.util.Comparator;
1818

19-
public class RerankOperatorOutputBuilder extends InferenceOperator.OutputBuilder<RankedDocsResults> {
19+
public class RerankOperatorOutputBuilder implements BulkInferenceOutputBuilder<RankedDocsResults, Page> {
2020

2121
private final Page inputPage;
2222
private final DoubleBlock.Builder scoreBlockBuilder;
@@ -28,9 +28,13 @@ public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page i
2828
this.scoreChannel = scoreChannel;
2929
}
3030

31+
@Override
32+
public Class<RankedDocsResults> inferenceResultsClass() {
33+
return RankedDocsResults.class;
34+
}
35+
3136
@Override
3237
public void close() {
33-
releasePageOnAnyThread(inputPage);
3438
Releasables.close(scoreBlockBuilder);
3539
}
3640

@@ -56,7 +60,7 @@ public Page buildOutput() {
5660
}
5761

5862
@Override
59-
public void addInferenceResult(RankedDocsResults results) {
63+
public void addInferenceResults(RankedDocsResults results) {
6064
results.getRankedDocs()
6165
.stream()
6266
.sorted(Comparator.comparingInt(RankedDocsResults.RankedDoc::index))

0 commit comments

Comments
 (0)