Skip to content

Commit 7463dad

Browse files
committed
Improved input management.
1 parent d99971f commit 7463dad

File tree

10 files changed

+53
-26
lines changed

10 files changed

+53
-26
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,8 @@ public Page getOutput() {
106106
}
107107

108108
try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) {
109-
assert ongoingInferenceResult.inputPage.getPositionCount() == ongoingInferenceResult.responses.size();
110109
for (InferenceAction.Response response : ongoingInferenceResult.responses) {
111-
try {
112-
outputBuilder.addInferenceResponse(response);
113-
} catch (IllegalArgumentException e) {
114-
throw new IllegalStateException("Invalid inference response", e);
115-
}
110+
outputBuilder.addInferenceResponse(response);
116111
}
117112
return outputBuilder.buildOutput();
118113

@@ -168,6 +163,10 @@ static <IR extends InferenceServiceResults> IR inferenceResults(InferenceAction.
168163
format("Inference result has wrong type. Got [{}] while expecting [{}]", results.getClass().getName(), clazz.getName())
169164
);
170165
}
166+
167+
default void releasePageOnAnyThread(Page page) {
168+
InferenceOperator.releasePageOnAnyThread(page);
169+
}
171170
}
172171

173172
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
2020
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2121

22+
import java.util.stream.IntStream;
23+
2224
/**
2325
* {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using prompt-based model (e.g., text completion).
2426
* It evaluates a prompt expression for each input row, constructs inference requests, and emits the model responses as output.
@@ -48,14 +50,25 @@ public String toString() {
4850
return "CompletionOperator[inference_id=[" + inferenceId() + "]]";
4951
}
5052

53+
@Override
54+
public void addInput(Page input) {
55+
try {
56+
super.addInput(input.appendBlock(promptEvaluator.eval(input)));
57+
} catch (Exception e) {
58+
releasePageOnAnyThread(input);
59+
throw(e);
60+
}
61+
}
62+
5163
/**
5264
* Constructs the completion inference requests iterator for the given input page by evaluating the prompt expression.
5365
*
5466
* @param inputPage The input data page.
5567
*/
5668
@Override
5769
protected BulkInferenceRequestIterator requests(Page inputPage) {
58-
return new CompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId());
70+
int inputBlockChannel = inputPage.getBlockCount() - 1;
71+
return new CompletionOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId());
5972
}
6073

6174
/**
@@ -65,7 +78,8 @@ protected BulkInferenceRequestIterator requests(Page inputPage) {
6578
*/
6679
@Override
6780
protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
68-
return new CompletionOperatorOutputBuilder(blockFactory().newBytesRefBlockBuilder(input.getPositionCount()), input);
81+
BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount());
82+
return new CompletionOperatorOutputBuilder(outputBlockBuilder, input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()));
6983
}
7084

7185
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder,
3333
@Override
3434
public void close() {
3535
Releasables.close(outputBlockBuilder);
36+
releasePageOnAnyThread(inputPage);
3637
}
3738

3839
/**
@@ -63,7 +64,6 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
6364
bytesRefBuilder.clear();
6465
}
6566
outputBlockBuilder.endPositionEntry();
66-
6767
}
6868

6969
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ public int estimatedSize() {
115115

116116
@Override
117117
public void close() {
118-
promptBlock.allowPassingToDifferentDriver();
119-
Releasables.close(promptBlock);
118+
120119
}
121120
}
122121
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.inference.rerank;
99

10-
import org.elasticsearch.compute.data.BytesRefBlock;
1110
import org.elasticsearch.compute.data.Page;
1211
import org.elasticsearch.compute.operator.DriverContext;
1312
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@@ -18,6 +17,8 @@
1817
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1918
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
2019

20+
import java.util.stream.IntStream;
21+
2122
/**
2223
* {@link RerankOperator} is an inference operator that compute scores for rows using a reranking model.
2324
*/
@@ -50,6 +51,16 @@ public RerankOperator(
5051
this.scoreChannel = scoreChannel;
5152
}
5253

54+
@Override
55+
public void addInput(Page input) {
56+
try {
57+
super.addInput(input.appendBlock(rowEncoder.eval(input)));
58+
} catch (Exception e) {
59+
releasePageOnAnyThread(input);
60+
throw(e);
61+
}
62+
}
63+
5364
@Override
5465
protected void doClose() {
5566
Releasables.close(rowEncoder);
@@ -65,15 +76,20 @@ public String toString() {
6576
*/
6677
@Override
6778
protected RerankOperatorRequestIterator requests(Page inputPage) {
68-
return new RerankOperatorRequestIterator((BytesRefBlock) rowEncoder.eval(inputPage), inferenceId(), queryText, batchSize);
79+
int inputBlockChannel = inputPage.getBlockCount() - 1;
80+
return new RerankOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), queryText, batchSize);
6981
}
7082

7183
/**
7284
* Returns the output builder responsible for collecting inference responses and building the output page.
7385
*/
7486
@Override
7587
protected RerankOperatorOutputBuilder outputBuilder(Page input) {
76-
return new RerankOperatorOutputBuilder(blockFactory().newDoubleBlockBuilder(input.getPositionCount()), input, scoreChannel);
88+
return new RerankOperatorOutputBuilder(
89+
blockFactory().newDoubleBlockBuilder(input.getPositionCount()),
90+
input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()),
91+
scoreChannel
92+
);
7793
}
7894

7995
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
1818

1919
import java.util.Comparator;
20+
import java.util.Iterator;
2021

2122
/**
2223
* Builds the output page for the {@link RerankOperator} by adding
@@ -38,6 +39,7 @@ public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page i
3839
@Override
3940
public void close() {
4041
Releasables.close(scoreBlockBuilder);
42+
releasePageOnAnyThread(inputPage);
4143
}
4244

4345
/**
@@ -77,11 +79,10 @@ public Page buildOutput() {
7779
*/
7880
@Override
7981
public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
80-
inferenceResults(inferenceResponse).getRankedDocs()
81-
.stream()
82-
.sorted(Comparator.comparingInt(RankedDocsResults.RankedDoc::index))
83-
.mapToDouble(RankedDocsResults.RankedDoc::relevanceScore)
84-
.forEach(scoreBlockBuilder::appendDouble);
82+
Iterator<RankedDocsResults.RankedDoc> sortedRankedDocIterator = inferenceResults(inferenceResponse).getRankedDocs().stream().sorted(Comparator.comparingInt(RankedDocsResults.RankedDoc::index)).iterator();
83+
while (sortedRankedDocIterator.hasNext()) {
84+
scoreBlockBuilder.appendDouble(sortedRankedDocIterator.next().relevanceScore());
85+
}
8586
}
8687

8788
private RankedDocsResults inferenceResults(InferenceAction.Response inferenceResponse) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.common.lucene.BytesRefs;
1212
import org.elasticsearch.compute.data.BytesRefBlock;
13-
import org.elasticsearch.core.Releasables;
1413
import org.elasticsearch.inference.TaskType;
1514
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1615
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
@@ -81,7 +80,6 @@ private InferenceAction.Request inferenceRequest(List<String> inputs) {
8180

8281
@Override
8382
public void close() {
84-
inputBlock.allowPassingToDifferentDriver();
85-
Releasables.close(inputBlock);
83+
8684
}
8785
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ protected ThreadPool threadPool() {
7878

7979
@Override
8080
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
81-
return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
81+
int minSize = Integer.min(1, size / 10);
82+
return new AbstractBlockSourceOperator(blockFactory, between(minSize, 8 * 1024)) {
8283
@Override
8384
protected int remaining() {
8485
return size - currentPosition;

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/RerankOperatorRequestIteratorTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ public void testIterateLargeInput() {
2828
}
2929

3030
private void assertIterate(int size, int batchSize) {
31-
final BytesRefBlock inputBlock = randomInputBlock(size);
3231
final String inferenceId = randomIdentifier();
3332
final String queryText = randomIdentifier();
3433

3534
try (
35+
BytesRefBlock inputBlock = randomInputBlock(size);
3636
RerankOperatorRequestIterator requestIterator = new RerankOperatorRequestIterator(inputBlock, inferenceId, queryText, batchSize)
3737
) {
3838
BytesRef scratch = new BytesRef();

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ public void testIterateLargeInput() {
2525
}
2626

2727
private void assertIterate(int size) {
28-
final BytesRefBlock inputBlock = randomInputBlock(size);
2928
final String inferenceId = randomIdentifier();
3029

31-
try (CompletionOperatorRequestIterator requestIterator = new CompletionOperatorRequestIterator(inputBlock, inferenceId)) {
30+
try (BytesRefBlock inputBlock = randomInputBlock(size); CompletionOperatorRequestIterator requestIterator = new CompletionOperatorRequestIterator(inputBlock, inferenceId)) {
3231
BytesRef scratch = new BytesRef();
3332

3433
for (int currentPos = 0; requestIterator.hasNext(); currentPos++) {

0 commit comments

Comments
 (0)