Skip to content

Commit 8f9921f

Browse files
committed
Fix a memory leak in the InferenceOperator
1 parent 8db5cd2 commit 8f9921f

File tree

12 files changed: 601 additions (+601) and 398 deletions (-398)

12 files changed: 601 additions (+601) and 398 deletions (-398)

muted-tests.yml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -315,9 +315,9 @@ tests:
315315
- class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
316316
method: testSearchWithRandomDisconnects
317317
issue: https://github.com/elastic/elasticsearch/issues/122707
318-
- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
319-
method: testSimpleCircuitBreaking
320-
issue: https://github.com/elastic/elasticsearch/issues/124337
318+
#- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
319+
# method: testSimpleCircuitBreaking
320+
# issue: https://github.com/elastic/elasticsearch/issues/124337
321321
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
322322
method: testSchedulerCloseWaitsForRunningMerge
323323
issue: https://github.com/elastic/elasticsearch/issues/125236

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/RandomBlock.java

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,11 +75,24 @@ public static RandomBlock randomBlock(
7575
int maxValuesPerPosition,
7676
int minDupsPerPosition,
7777
int maxDupsPerPosition
78+
) {
79+
return randomBlock(blockFactory, elementType, positionCount, ESTestCase.randomBoolean(), nullAllowed, minValuesPerPosition, maxValuesPerPosition, minDupsPerPosition, maxDupsPerPosition);
80+
}
81+
82+
public static RandomBlock randomBlock(
83+
BlockFactory blockFactory,
84+
ElementType elementType,
85+
int positionCount,
86+
boolean bytesRefFromPoints,
87+
boolean nullAllowed,
88+
int minValuesPerPosition,
89+
int maxValuesPerPosition,
90+
int minDupsPerPosition,
91+
int maxDupsPerPosition
7892
) {
7993
List<List<Object>> values = new ArrayList<>();
8094
Block.MvOrdering mvOrdering = Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING;
8195
try (var builder = elementType.newBlockBuilder(positionCount, blockFactory)) {
82-
boolean bytesRefFromPoints = ESTestCase.randomBoolean();
8396
Supplier<Point> pointSupplier = ESTestCase.randomBoolean() ? GeometryTestUtils::randomPoint : ShapeTestUtils::randomPoint;
8497
for (int p = 0; p < positionCount; p++) {
8598
if (elementType == ElementType.NULL) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/CompletionOperator.java

Lines changed: 87 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525

2626
import java.util.List;
2727
import java.util.NoSuchElementException;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2829

2930
public class CompletionOperator extends InferenceOperator<ChatCompletionResults> {
3031

@@ -33,7 +34,7 @@ public record Factory(InferenceRunner inferenceRunner, String inferenceId, Expre
3334
OperatorFactory {
3435
@Override
3536
public String describe() {
36-
return "Completion[inference_id=[" + inferenceId + "]]";
37+
return "CompletionOperator[inference_id=[" + inferenceId + "]]";
3738
}
3839

3940
@Override
@@ -73,82 +74,102 @@ public String toString() {
7374

7475
@Override
7576
protected BulkInferenceRequestIterator requests(Page inputPage) {
76-
return new BulkInferenceRequestIterator() {
77-
private final BytesRefBlock promptBlock = (BytesRefBlock) promptEvaluator.eval(inputPage);
78-
private BytesRef readBuffer = new BytesRef();
79-
private int currentPos = 0;
80-
81-
@Override
82-
public boolean hasNext() {
83-
return currentPos < promptBlock.getPositionCount();
84-
}
85-
86-
@Override
87-
public InferenceAction.Request next() {
88-
if (hasNext() == false) {
89-
throw new NoSuchElementException();
77+
final BytesRefBlock promptBlock = (BytesRefBlock) promptEvaluator.eval(inputPage);
78+
try {
79+
return new BulkInferenceRequestIterator() {
80+
private int currentPos = 0;
81+
BytesRef readBuffer = new BytesRef();
82+
@Override
83+
public boolean hasNext() {
84+
return currentPos < promptBlock.getPositionCount();
9085
}
91-
int pos = currentPos++;
9286

93-
if (promptBlock.isNull(pos)) {
94-
return null;
95-
}
87+
@Override
88+
public InferenceAction.Request next() {
89+
if (hasNext() == false) {
90+
throw new NoSuchElementException();
91+
}
92+
int pos = currentPos++;
9693

97-
StringBuilder promptBuilder = new StringBuilder();
98-
for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
99-
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
100-
promptBuilder.append(readBuffer.utf8ToString()).append("\n");
101-
}
94+
if (promptBlock.isNull(pos)) {
95+
return null;
96+
}
97+
98+
StringBuilder promptBuilder = new StringBuilder();
99+
for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
100+
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
101+
promptBuilder.append(readBuffer.utf8ToString()).append("\n");
102+
}
102103

103-
return inferenceRequest(promptBuilder.toString());
104-
}
104+
return inferenceRequest(promptBuilder.toString());
105+
}
105106

106-
@Override
107-
public void close() {
108-
promptBlock.allowPassingToDifferentDriver();
109-
Releasables.closeExpectNoException(promptBlock);
110-
}
111-
};
107+
@Override
108+
public void close() {
109+
promptBlock.allowPassingToDifferentDriver();
110+
Releasables.closeExpectNoException(promptBlock);
111+
}
112+
};
113+
} catch (Exception e) {
114+
promptBlock.allowPassingToDifferentDriver();
115+
Releasables.closeExpectNoException(promptBlock);
116+
throw(e);
117+
}
112118
}
113119

114120
@Override
115121
protected BulkInferenceOutputBuilder<ChatCompletionResults, Page> outputBuilder(Page inputPage) {
116-
return new BulkInferenceOutputBuilder<>() {
117-
private final BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(inputPage.getPositionCount());
118-
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
119-
120-
@Override
121-
public void close() {
122-
Releasables.closeExpectNoException(outputBlockBuilder);
123-
}
124-
125-
@Override
126-
public void onInferenceResults(ChatCompletionResults completionResults) {
127-
if (completionResults == null || completionResults.getResults().isEmpty()) {
128-
outputBlockBuilder.appendNull();
129-
} else {
130-
outputBlockBuilder.beginPositionEntry();
131-
for (ChatCompletionResults.Result rankedDocsResult : completionResults.getResults()) {
132-
bytesRefBuilder.copyChars(rankedDocsResult.content());
133-
outputBlockBuilder.appendBytesRef(bytesRefBuilder.get());
134-
bytesRefBuilder.clear();
122+
final BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(inputPage.getPositionCount());
123+
final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
124+
final AtomicBoolean isOutputBuilt = new AtomicBoolean(false);
125+
126+
try {
127+
return new BulkInferenceOutputBuilder<>() {
128+
@Override
129+
public void close() {
130+
if (isOutputBuilt.get() == false) {
131+
releasePageOnAnyThread(inputPage);
132+
}
133+
134+
Releasables.closeExpectNoException(outputBlockBuilder);
135+
}
136+
137+
@Override
138+
public void onInferenceResults(ChatCompletionResults completionResults) {
139+
if (completionResults == null || completionResults.getResults().isEmpty()) {
140+
outputBlockBuilder.appendNull();
141+
} else {
142+
outputBlockBuilder.beginPositionEntry();
143+
for (ChatCompletionResults.Result rankedDocsResult : completionResults.getResults()) {
144+
bytesRefBuilder.copyChars(rankedDocsResult.content());
145+
outputBlockBuilder.appendBytesRef(bytesRefBuilder.get());
146+
bytesRefBuilder.clear();
147+
}
148+
outputBlockBuilder.endPositionEntry();
149+
}
150+
}
151+
152+
@Override
153+
protected Class<ChatCompletionResults> inferenceResultsClass() {
154+
return ChatCompletionResults.class;
155+
}
156+
157+
@Override
158+
public Page buildOutput() {
159+
if (isOutputBuilt.compareAndSet(false, true)) {
160+
Block outputBlock = outputBlockBuilder.build();
161+
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
162+
return inputPage.appendBlock(outputBlock);
135163
}
136-
outputBlockBuilder.endPositionEntry();
164+
165+
throw new IllegalStateException("buildOutput has already been called");
137166
}
138-
}
139-
140-
@Override
141-
protected Class<ChatCompletionResults> inferenceResultsClass() {
142-
return ChatCompletionResults.class;
143-
}
144-
145-
@Override
146-
public Page buildOutput() {
147-
Block outputBlock = outputBlockBuilder.build();
148-
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
149-
return inputPage.appendBlock(outputBlock);
150-
}
151-
};
167+
};
168+
} catch (Exception e) {
169+
releasePageOnAnyThread(inputPage);
170+
Releasables.closeExpectNoException(outputBlockBuilder);
171+
throw(e);
172+
}
152173
}
153174

154175
private InferenceAction.Request inferenceRequest(String prompt) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,17 @@ public Page getOutput() {
5656

5757
@Override
5858
protected void performAsync(Page input, ActionListener<Page> listener) {
59-
bulkInferenceExecutor.execute(requests(input), outputBuilder(input), listener);
59+
try {
60+
BulkInferenceOutputBuilder<InferenceResult, Page> outputBuilder = outputBuilder(input);
61+
listener = ActionListener.releaseBefore(outputBuilder, listener);
62+
63+
BulkInferenceRequestIterator requests = requests(input);
64+
listener = ActionListener.releaseBefore(requests, listener);
65+
66+
bulkInferenceExecutor.execute(requests, outputBuilder, listener);
67+
} catch (Exception e) {
68+
listener.onFailure(e);
69+
}
6070
}
6171

6272
protected BulkInferenceExecutionConfig bulkExecutionConfig() {

0 commit comments

Comments
 (0)