Skip to content

Commit d3f7a0e

Browse files
committed
Fix out of memory error in the CompletionOperatorTests
1 parent 3a8422d commit d3f7a0e

File tree

10 files changed

+198
-55
lines changed

10 files changed

+198
-55
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ protected void releaseFetchedOnAnyThread(OngoingInference result) {
5959
protected void performAsync(Page input, ActionListener<OngoingInference> listener) {
6060
try {
6161
BulkInferenceRequestIterator requests = requests(input);
62-
listener = ActionListener.releaseAfter(listener, requests);
62+
listener = ActionListener.releaseBefore(requests, listener);
6363
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInference(input, responses)));
6464
} catch (Exception e) {
6565
listener.onFailure(e);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,11 @@
88
package org.elasticsearch.xpack.esql.inference.bulk;
99

1010
public record BulkInferenceExecutionConfig(int workers, int maxOutstandingRequests) {
11-
public static final int DEFAULT_WORKERS = 2;
11+
public static final int DEFAULT_WORKERS = 10;
1212
public static final int DEFAULT_MAX_OUTSTANDING_REQUESTS = 50;
1313

14-
public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig(DEFAULT_WORKERS, DEFAULT_MAX_OUTSTANDING_REQUESTS);
14+
public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig(
15+
DEFAULT_WORKERS,
16+
DEFAULT_MAX_OUTSTANDING_REQUESTS
17+
);
1518
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,19 +39,20 @@ public long getMaxSeqNo() {
3939
return checkpoint.getMaxSeqNo();
4040
}
4141

42-
public void onInferenceResponse(long seqNo, InferenceAction.Response response) {
42+
public synchronized void onInferenceResponse(long seqNo, InferenceAction.Response response) {
4343
if (failureCollector.hasFailure() == false) {
4444
bufferedResponses.put(seqNo, response);
4545
}
4646
checkpoint.markSeqNoAsProcessed(seqNo);
4747
}
4848

49-
public void onInferenceException(long seqNo, Exception e) {
49+
public synchronized void onInferenceException(long seqNo, Exception e) {
5050
failureCollector.unwrapAndCollect(e);
5151
checkpoint.markSeqNoAsProcessed(seqNo);
52+
bufferedResponses.clear();
5253
}
5354

54-
public InferenceAction.Response fetchBufferedResponse(long seqNo) {
55+
public synchronized InferenceAction.Response fetchBufferedResponse(long seqNo) {
5556
return bufferedResponses.remove(seqNo);
5657
}
5758

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ public void execute(BulkInferenceRequestIterator requests, ActionListener<List<I
5353
r -> bulkExecutionState.onInferenceResponse(seqNo, r),
5454
e -> bulkExecutionState.onInferenceException(seqNo, e)
5555
),
56-
responseHandler::persistsInferenceResponses
56+
responseHandler::persistPendingResponses
5757
)
5858
);
5959
}
@@ -75,15 +75,15 @@ private ResponseHandler(
7575
this.responses = new ArrayList<>(estimatedSize);
7676
}
7777

78-
public synchronized void persistsInferenceResponses() {
78+
public synchronized void persistPendingResponses() {
7979
long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint();
8080

8181
while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) {
8282
persistedSeqNo++;
83-
InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo);
84-
assert response != null || bulkExecutionState.hasFailure();
8583
if (bulkExecutionState.hasFailure() == false) {
8684
try {
85+
InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo);
86+
assert response != null;
8787
responses.add(response);
8888
} catch (Exception e) {
8989
bulkExecutionState.addFailure(e);
@@ -148,7 +148,7 @@ private void executePendingRequests() {
148148

149149
try {
150150
executorService.execute(task);
151-
} catch (Exception e){
151+
} catch (Exception e) {
152152
task.onFailure(e);
153153
permits.release();
154154
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
5656

5757
public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
5858
implements
59-
OperatorFactory {
59+
OperatorFactory {
6060
@Override
6161
public String describe() {
6262
return "CompletionOperator[inference_id=[" + inferenceId + "]]";

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 22 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,10 @@
99

1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.ActionRequest;
13+
import org.elasticsearch.action.ActionResponse;
14+
import org.elasticsearch.action.ActionType;
15+
import org.elasticsearch.client.internal.Client;
1216
import org.elasticsearch.common.logging.LoggerMessageFormat;
1317
import org.elasticsearch.common.settings.Settings;
1418
import org.elasticsearch.common.util.concurrent.EsExecutors;
@@ -28,8 +32,8 @@
2832
import org.elasticsearch.compute.operator.SourceOperator;
2933
import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
3034
import org.elasticsearch.compute.test.OperatorTestCase;
31-
import org.elasticsearch.core.TimeValue;
3235
import org.elasticsearch.inference.InferenceServiceResults;
36+
import org.elasticsearch.test.client.NoOpClient;
3337
import org.elasticsearch.threadpool.FixedExecutorBuilder;
3438
import org.elasticsearch.threadpool.TestThreadPool;
3539
import org.elasticsearch.threadpool.ThreadPool;
@@ -44,10 +48,6 @@
4448
import static org.hamcrest.Matchers.equalTo;
4549
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
4650
import static org.hamcrest.Matchers.notNullValue;
47-
import static org.mockito.ArgumentMatchers.any;
48-
import static org.mockito.Mockito.doAnswer;
49-
import static org.mockito.Mockito.mock;
50-
import static org.mockito.Mockito.when;
5151

5252
public abstract class InferenceOperatorTestCase<InferenceResultsType extends InferenceServiceResults> extends OperatorTestCase {
5353
private ThreadPool threadPool;
@@ -110,33 +110,26 @@ public void testOperatorStatus() {
110110
}
111111
}
112112

113+
@SuppressWarnings("unchecked")
113114
protected InferenceRunner mockedSimpleInferenceRunner() {
114-
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
115-
when(inferenceRunner.threadPool()).thenReturn(threadPool());
116-
doAnswer(i -> {
117-
Runnable sendResponse = () -> {
118-
@SuppressWarnings("unchecked")
119-
ActionListener<InferenceAction.Response> listener = i.getArgument(1, ActionListener.class);
120-
InferenceAction.Request request = i.getArgument(0, InferenceAction.Request.class);
121-
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
122-
doAnswer(invocation -> mockInferenceResult(request)).when(inferenceResponse).getResults();
123-
124-
listener.onResponse(inferenceResponse);
125-
};
126-
127-
if (randomBoolean()) {
128-
sendResponse.run();
129-
} else {
130-
threadPool.schedule(
131-
sendResponse,
132-
TimeValue.timeValueNanos(between(1, 100)),
133-
threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME)
134-
);
115+
Client client = new NoOpClient(threadPool) {
116+
@Override
117+
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
118+
ActionType<Response> action,
119+
Request request,
120+
ActionListener<Response> listener
121+
) {
122+
if (action == InferenceAction.INSTANCE && request instanceof InferenceAction.Request inferenceRequest) {
123+
InferenceAction.Response inferenceResponse = new InferenceAction.Response(mockInferenceResult(inferenceRequest));
124+
listener.onResponse((Response) inferenceResponse);
125+
return;
126+
}
127+
128+
fail("Unexpected call to action [" + action.name() + "]");
135129
}
130+
};
136131

137-
return null;
138-
}).when(inferenceRunner).doInference(any(InferenceAction.Request.class), any());
139-
return inferenceRunner;
132+
return new InferenceRunner(client, threadPool);
140133
}
141134

142135
protected abstract InferenceResultsType mockInferenceResult(InferenceAction.Request request);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ public void shutdownThreadPool() {
6161
}
6262

6363
public void testSuccessfulExecution() throws Exception {
64-
List<InferenceAction.Request> requests = randomInferenceRequestList(10_000);
64+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 100_000));
6565
List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size());
6666

6767
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
@@ -93,7 +93,7 @@ public void testSuccessfulExecutionOnEmptyRequest() throws Exception {
9393
}
9494

9595
public void testInferenceRunnerAlwaysFails() throws Exception {
96-
List<InferenceAction.Request> requests = randomInferenceRequestList(10_000);
96+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 100_000));
9797

9898
InferenceRunner inferenceRunner = mock(invocation -> {
9999
runWithRandomDelay(() -> {
@@ -115,7 +115,7 @@ public void testInferenceRunnerAlwaysFails() throws Exception {
115115
}
116116

117117
public void testInferenceRunnerSometimesFails() throws Exception {
118-
List<InferenceAction.Request> requests = randomInferenceRequestList(10_000);
118+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 100_000));
119119

120120
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
121121
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
@@ -180,7 +180,7 @@ private List<InferenceAction.Request> randomInferenceRequestList(int size) {
180180
private List<InferenceAction.Response> randomInferenceResponseList(int size) {
181181
List<InferenceAction.Response> response = new ArrayList<>(size);
182182
while (response.size() < size) {
183-
response.add(this.mockInferenceResponse());
183+
response.add(mock(InferenceAction.Response.class));
184184
}
185185
return response;
186186
}
@@ -197,7 +197,7 @@ private void runWithRandomDelay(Runnable runnable) {
197197
} else {
198198
threadPool.schedule(
199199
runnable,
200-
TimeValue.timeValueNanos(between(1, 100)),
200+
TimeValue.timeValueNanos(between(1, 1_000)),
201201
threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME)
202202
);
203203
}
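x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change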
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.completion;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.Block;
12+
import org.elasticsearch.compute.data.BytesRefBlock;
13+
import org.elasticsearch.compute.data.Page;
14+
import org.elasticsearch.compute.test.ComputeTestCase;
15+
import org.elasticsearch.compute.test.RandomBlock;
16+
import org.elasticsearch.core.Releasables;
17+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
18+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
19+
20+
import java.util.List;
21+
22+
import static org.hamcrest.Matchers.equalTo;
23+
24+
public class CompletionOperatorOutputBuilderTests extends ComputeTestCase {
25+
26+
public void testBuildSmallOutput() {
27+
assertBuildOutput(between(1, 100));
28+
}
29+
30+
public void testBuildLargeOutput() {
31+
assertBuildOutput(between(10_000, 100_000));
32+
}
33+
34+
private void assertBuildOutput(int size) {
35+
Page inputPage = randomInputPage(size, between(1, 20));
36+
try (
37+
CompletionOperatorOutputBuilder outputBuilder = new CompletionOperatorOutputBuilder(
38+
blockFactory().newBytesRefBlockBuilder(size),
39+
inputPage
40+
)
41+
) {
42+
for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) {
43+
List<ChatCompletionResults.Result> results = List.of(new ChatCompletionResults.Result("Completion result #" + currentPos));
44+
outputBuilder.addInferenceResponse(new InferenceAction.Response(new ChatCompletionResults(results)));
45+
}
46+
47+
Page outputPage = outputBuilder.buildOutput();
48+
assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount()));
49+
assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1));
50+
assertOutputContent(outputPage.getBlock(outputPage.getBlockCount() - 1));
51+
52+
outputPage.releaseBlocks();
53+
54+
} finally {
55+
inputPage.releaseBlocks();
56+
}
57+
58+
}
59+
60+
private void assertOutputContent(BytesRefBlock block) {
61+
BytesRef scratch = new BytesRef();
62+
63+
for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) {
64+
assertThat(block.isNull(currentPos), equalTo(false));
65+
scratch = block.getBytesRef(block.getFirstValueIndex(currentPos), scratch);
66+
assertThat(scratch.utf8ToString(), equalTo("Completion result #" + currentPos));
67+
}
68+
}
69+
70+
private Page randomInputPage(int positionCount, int columnCount) {
71+
Block[] blocks = new Block[columnCount];
72+
try {
73+
for (int i = 0; i < columnCount; i++) {
74+
blocks[i] = RandomBlock.randomBlock(
75+
blockFactory(),
76+
RandomBlock.randomElementType(),
77+
positionCount,
78+
randomBoolean(),
79+
0,
80+
0,
81+
randomInt(10),
82+
randomInt(10)
83+
).block();
84+
}
85+
86+
return new Page(blocks);
87+
} catch (Exception e) {
88+
Releasables.close(blocks);
89+
throw (e);
90+
}
91+
}
92+
}
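x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change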
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.completion;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.BytesRefBlock;
12+
import org.elasticsearch.compute.test.ComputeTestCase;
13+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
15+
import static org.hamcrest.Matchers.equalTo;
16+
17+
public class CompletionOperatorRequestIteratorTests extends ComputeTestCase {
18+
19+
public void testIterateSmallInput() {
20+
assertIterate(between(1, 100));
21+
}
22+
23+
public void testIterateLargeInput() {
24+
assertIterate(between(10_000, 100_000));
25+
}
26+
27+
private void assertIterate(int size) {
28+
final BytesRefBlock inputBlock = randomInputBlock(size);
29+
final String inferenceId = randomIdentifier();
30+
31+
try (CompletionOperatorRequestIterator requestIterator = new CompletionOperatorRequestIterator(inputBlock, inferenceId)) {
32+
BytesRef scratch = new BytesRef();
33+
34+
for (int currentPos = 0; requestIterator.hasNext(); currentPos++) {
35+
InferenceAction.Request request = requestIterator.next();
36+
assertThat(request.getInferenceEntityId(), equalTo(inferenceId));
37+
scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch);
38+
assertThat(request.getInput().getFirst(), equalTo(scratch.utf8ToString()));
39+
}
40+
}
41+
}
42+
43+
private BytesRefBlock randomInputBlock(int size) {
44+
try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) {
45+
for (int i = 0; i < size; i++) {
46+
blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10)));
47+
}
48+
49+
return blockBuilder.build();
50+
}
51+
}
52+
}

0 commit comments

Comments
 (0)