Skip to content

Commit 3b1a649

Browse files
committed
Better implementation of OutputBuilder.
1 parent 60c804a commit 3b1a649

File tree

5 files changed

+30
-31
lines changed

5 files changed

+30
-31
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ public Page getOutput() {
112112
outputBuilder.addInferenceResponse(response);
113113
}
114114
return outputBuilder.buildOutput();
115-
116-
} finally {
115+
} catch (Exception e) {
117116
releaseFetchedOnAnyThread(ongoingInferenceResult);
117+
throw e;
118118
}
119119
}
120120

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder,
3333
@Override
3434
public void close() {
3535
Releasables.close(outputBlockBuilder);
36-
releasePageOnAnyThread(inputPage);
3736
}
3837

3938
/**
@@ -72,13 +71,13 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
7271
}
7372

7473
/**
75-
* Builds the final output page by appending the completion output block to a shallow copy of the input page.
74+
* Builds the final output page by appending the completion output block to the input page.
7675
*/
7776
@Override
7877
public Page buildOutput() {
7978
Block outputBlock = outputBlockBuilder.build();
8079
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
81-
return inputPage.shallowCopy().appendBlock(outputBlock);
80+
return inputPage.appendBlock(outputBlock);
8281
}
8382

8483
private ChatCompletionResults inferenceResults(InferenceAction.Response inferenceResponse) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.inference.rerank;
99

10-
import org.elasticsearch.compute.data.Block;
1110
import org.elasticsearch.compute.data.DoubleBlock;
1211
import org.elasticsearch.compute.data.Page;
1312
import org.elasticsearch.core.Releasables;
@@ -18,6 +17,7 @@
1817

1918
import java.util.Comparator;
2019
import java.util.Iterator;
20+
import java.util.stream.IntStream;
2121

2222
/**
2323
* Builds the output page for the {@link RerankOperator} by adding
@@ -39,7 +39,6 @@ public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page i
3939
@Override
4040
public void close() {
4141
Releasables.close(scoreBlockBuilder);
42-
releasePageOnAnyThread(inputPage);
4342
}
4443

4544
/**
@@ -49,21 +48,24 @@ public void close() {
4948
@Override
5049
public Page buildOutput() {
5150
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
52-
Block[] blocks = new Block[blockCount];
51+
Page outputPage = inputPage.appendBlock(scoreBlockBuilder.build());
52+
53+
if (scoreChannel == inputPage.getBlockCount()) {
54+
// Just need to append the block at the end
55+
// We can just return the output page we have just created
56+
return outputPage;
57+
}
5358

5459
try {
55-
for (int b = 0; b < blockCount; b++) {
56-
if (b == scoreChannel) {
57-
blocks[b] = scoreBlockBuilder.build();
58-
} else {
59-
blocks[b] = inputPage.getBlock(b);
60-
blocks[b].incRef();
61-
}
62-
}
63-
return new Page(blocks);
64-
} catch (Exception e) {
65-
Releasables.close(blocks);
66-
throw (e);
60+
// We need to project the last column to the score channel.
61+
int[] blockNapping = IntStream.range(0, inputPage.getBlockCount())
62+
.map(channel -> channel == scoreChannel ? inputPage.getBlockCount() : channel)
63+
.toArray();
64+
65+
return outputPage.projectBlocks(blockNapping);
66+
} finally {
67+
// Releasing the output page since projection is incrementing block references.
68+
releasePageOnAnyThread(outputPage);
6769
}
6870
}
6971

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323

2424
public class CompletionOperatorOutputBuilderTests extends ComputeTestCase {
2525

26-
public void testBuildSmallOutput() {
26+
public void testBuildSmallOutput() throws Exception {
2727
assertBuildOutput(between(1, 100));
2828
}
2929

30-
public void testBuildLargeOutput() {
30+
public void testBuildLargeOutput() throws Exception {
3131
assertBuildOutput(between(10_000, 100_000));
3232
}
3333

34-
private void assertBuildOutput(int size) {
34+
private void assertBuildOutput(int size) throws Exception {
3535
final Page inputPage = randomInputPage(size, between(1, 20));
3636
try (
3737
CompletionOperatorOutputBuilder outputBuilder = new CompletionOperatorOutputBuilder(
@@ -50,11 +50,9 @@ private void assertBuildOutput(int size) {
5050
assertOutputContent(outputPage.getBlock(outputPage.getBlockCount() - 1));
5151

5252
outputPage.releaseBlocks();
53-
54-
} finally {
55-
inputPage.releaseBlocks();
5653
}
5754

55+
allBreakersEmpty();
5856
}
5957

6058
private void assertOutputContent(BytesRefBlock block) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323

2424
public class RerankOperatorOutputBuilderTests extends ComputeTestCase {
2525

26-
public void testBuildSmallOutput() {
26+
public void testBuildSmallOutput() throws Exception {
2727
assertBuildOutput(between(1, 100));
2828
}
2929

30-
public void testBuildLargeOutput() {
30+
public void testBuildLargeOutput() throws Exception {
3131
assertBuildOutput(between(10_000, 100_000));
3232
}
3333

34-
private void assertBuildOutput(int size) {
34+
private void assertBuildOutput(int size) throws Exception {
3535
final Page inputPage = randomInputPage(size, between(1, 20));
3636
final int scoreChannel = randomIntBetween(0, inputPage.getBlockCount());
3737
try (
@@ -61,9 +61,9 @@ private void assertBuildOutput(int size) {
6161
outputPage.releaseBlocks();
6262
}
6363

64-
} finally {
65-
inputPage.releaseBlocks();
6664
}
65+
66+
allBreakersEmpty();
6767
}
6868

6969
private float relevanceScore(int position) {

0 commit comments

Comments
 (0)