Skip to content

Commit 3916eb9

Browse files
committed
Another refactoring
1 parent fe091b4 commit 3916eb9

File tree

7 files changed

+103
-105
lines changed

7 files changed

+103
-105
lines changed

muted-tests.yml

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -306,9 +306,6 @@ tests:
306306
- class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
307307
method: testSearchWithRandomDisconnects
308308
issue: https://github.com/elastic/elasticsearch/issues/122707
309-
#- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
310-
# method: testSimpleCircuitBreaking
311-
# issue: https://github.com/elastic/elasticsearch/issues/124337
312309
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
313310
method: testSchedulerCloseWaitsForRunningMerge
314311
issue: https://github.com/elastic/elasticsearch/issues/125236

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -56,13 +56,7 @@ public Page getOutput() {
5656

5757
@Override
5858
protected void performAsync(Page input, ActionListener<Page> listener) {
59-
try {
60-
BulkInferenceOutputBuilder<InferenceResult, Page> outputBuilder = outputBuilder(input);
61-
listener = ActionListener.releaseBefore(outputBuilder, listener);
62-
63-
BulkInferenceRequestIterator requests = requests(input);
64-
listener = ActionListener.releaseBefore(requests, listener);
65-
59+
try (OutputBuilder<InferenceResult> outputBuilder = outputBuilder(input); BulkInferenceRequestIterator requests = requests(input)) {
6660
bulkInferenceExecutor.execute(requests, outputBuilder, listener);
6761
} catch (Exception e) {
6862
listener.onFailure(e);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 32 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -7,66 +7,71 @@
77

88
package org.elasticsearch.xpack.esql.inference.bulk;
99

10-
import org.elasticsearch.action.ActionListener;
10+
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
1111
import org.elasticsearch.compute.operator.FailureCollector;
12-
import org.elasticsearch.core.CheckedConsumer;
13-
import org.elasticsearch.core.CheckedSupplier;
1412
import org.elasticsearch.index.seqno.LocalCheckpointTracker;
1513
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1614

1715
import java.util.Map;
16+
import java.util.concurrent.BlockingQueue;
1817
import java.util.concurrent.ConcurrentHashMap;
19-
import java.util.concurrent.atomic.AtomicBoolean;
18+
import java.util.concurrent.TimeUnit;
2019

2120
import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
2221

2322
public class BulkInferenceExecutionState {
2423
private final LocalCheckpointTracker checkpoint = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
2524
private final FailureCollector failureCollector = new FailureCollector();
26-
private final AtomicBoolean responseSent = new AtomicBoolean(false);
27-
private final AtomicBoolean allRequestsSent = new AtomicBoolean(false);
28-
final Map<Long, InferenceAction.Response> bufferedResponses = new ConcurrentHashMap<>();
25+
private final Map<Long, InferenceAction.Response> bufferedResponses = new ConcurrentHashMap<>();
26+
private final BlockingQueue<Long> processedSeqNoQueue = ConcurrentCollections.newBlockingQueue();
2927

3028
public long generateSeqNo() {
3129
return checkpoint.generateSeqNo();
3230
}
3331

32+
public long getPersistedCheckpoint() {
33+
return checkpoint.getPersistedCheckpoint();
34+
}
35+
36+
public long getProcessedCheckpoint() {
37+
return checkpoint.getProcessedCheckpoint();
38+
}
39+
40+
public long getMaxSeqNo() {
41+
return checkpoint.getMaxSeqNo();
42+
}
43+
3444
public void onInferenceResponse(long seqNo, InferenceAction.Response response) {
3545
if (failureCollector.hasFailure() == false) {
3646
bufferedResponses.put(seqNo, response);
3747
}
3848
checkpoint.markSeqNoAsProcessed(seqNo);
49+
processedSeqNoQueue.offer(seqNo);
3950
}
4051

4152
public void onInferenceException(long seqNo, Exception e) {
4253
failureCollector.unwrapAndCollect(e);
4354
checkpoint.markSeqNoAsProcessed(seqNo);
55+
processedSeqNoQueue.offer(seqNo);
4456
}
4557

46-
public void markSeqNoAsPersisted(long seqNo) {
47-
checkpoint.markSeqNoAsPersisted(seqNo);
58+
public Long fetchProcessedSeqNo() {
59+
try {
60+
return processedSeqNoQueue.poll(1, TimeUnit.SECONDS);
61+
} catch (InterruptedException e) {
62+
return null;
63+
}
4864
}
4965

50-
public void persistsResponses(CheckedConsumer<InferenceAction.Response, Exception> persister) {
51-
synchronized (checkpoint) {
52-
long persistedSeqNo = checkpoint.getPersistedCheckpoint();
53-
while (persistedSeqNo < checkpoint.getProcessedCheckpoint()) {
54-
persistedSeqNo++;
55-
InferenceAction.Response response = bufferedResponses.remove(persistedSeqNo);
56-
assert response != null || hasFailure();
57-
if (hasFailure() == false && responseSent() == false) {
58-
try {
59-
persister.accept(response);
60-
} catch (Exception e) {
61-
failureCollector.unwrapAndCollect(e);
62-
}
63-
}
64-
checkpoint.markSeqNoAsPersisted(persistedSeqNo);
65-
}
66-
}
66+
public InferenceAction.Response fetchBufferedResponse(long seqNo) {
67+
return bufferedResponses.remove(seqNo);
68+
}
69+
70+
public void markSeqNoAsPersisted(long seqNo) {
71+
checkpoint.markSeqNoAsPersisted(seqNo);
6772
}
6873

69-
private boolean hasFailure() {
74+
public boolean hasFailure() {
7075
return failureCollector.hasFailure();
7176
}
7277

@@ -77,40 +82,4 @@ public Exception getFailure() {
7782
public void addFailure(Exception e) {
7883
failureCollector.unwrapAndCollect(e);
7984
}
80-
81-
public <OutputType> void maybeSendResponse(CheckedSupplier<OutputType, Exception> responseBuilder, ActionListener<OutputType> l) {
82-
if (allRequestsSent() == false) {
83-
return;
84-
}
85-
86-
if (checkpoint.getPersistedCheckpoint() < checkpoint.getMaxSeqNo()) {
87-
return;
88-
}
89-
90-
if (responseSent.compareAndSet(false, true)) {
91-
if (failureCollector.hasFailure() == false) {
92-
try {
93-
l.onResponse(responseBuilder.get());
94-
return;
95-
} catch (Exception e) {
96-
l.onFailure(e);
97-
return;
98-
}
99-
}
100-
101-
l.onFailure(failureCollector.getFailure());
102-
}
103-
}
104-
105-
public boolean responseSent() {
106-
return responseSent.get();
107-
}
108-
109-
public void markAllRequestsSent() {
110-
allRequestsSent.set(true);
111-
}
112-
113-
public boolean allRequestsSent() {
114-
return allRequestsSent.get();
115-
}
11685
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 66 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
12+
import org.elasticsearch.core.CheckedConsumer;
1213
import org.elasticsearch.inference.InferenceServiceResults;
1314
import org.elasticsearch.threadpool.ThreadPool;
1415
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
@@ -32,37 +33,75 @@ public <InferenceResult extends InferenceServiceResults, OutputType> void execut
3233
BulkInferenceOutputBuilder<InferenceResult, OutputType> outputBuilder,
3334
ActionListener<OutputType> listener
3435
) {
36+
if (requests.hasNext() == false) {
37+
listener.onResponse(outputBuilder.buildOutput());
38+
return;
39+
}
40+
3541
final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState();
3642

37-
if (requests.hasNext() == false) {
38-
bulkExecutionState.markAllRequestsSent();
39-
bulkExecutionState.maybeSendResponse(outputBuilder::buildOutput, listener);
43+
try {
44+
enqueueRequests(requests, bulkExecutionState);
45+
persistsInferenceResponses(bulkExecutionState, outputBuilder::onInferenceResponse);
46+
} catch (Exception e) {
47+
listener.onFailure(e);
4048
}
4149

42-
while (requests.hasNext() && bulkExecutionState.responseSent() == false) {
43-
long seqNo = bulkExecutionState.generateSeqNo();
44-
InferenceAction.Request request = requests.next();
50+
if (bulkExecutionState.hasFailure() == false) {
4551
try {
46-
throttledInferenceRunner.doInference(
47-
request,
48-
ActionListener.runAfter(
49-
ActionListener.wrap(
50-
r -> bulkExecutionState.onInferenceResponse(seqNo, r),
51-
e -> bulkExecutionState.onInferenceException(seqNo, e)
52-
),
53-
() -> {
54-
bulkExecutionState.persistsResponses(outputBuilder::onInferenceResponse);
55-
bulkExecutionState.maybeSendResponse(outputBuilder::buildOutput, listener);
56-
}
57-
)
58-
);
59-
} catch (InterruptedException | TimeoutException e) {
60-
bulkExecutionState.addFailure(e);
61-
bulkExecutionState.maybeSendResponse(outputBuilder::buildOutput, listener);
52+
listener.onResponse(outputBuilder.buildOutput());
53+
return;
54+
} catch (Exception e) {
55+
listener.onFailure(e);
56+
return;
6257
}
6358
}
6459

65-
bulkExecutionState.markAllRequestsSent();
60+
listener.onFailure(bulkExecutionState.getFailure());
61+
}
62+
63+
private void enqueueRequests(BulkInferenceRequestIterator requests, BulkInferenceExecutionState bulkExecutionState) {
64+
while (requests.hasNext()) {
65+
InferenceAction.Request request = requests.next();
66+
long seqNo = bulkExecutionState.generateSeqNo();
67+
ActionListener<InferenceAction.Response> listener = ActionListener.wrap(
68+
r -> bulkExecutionState.onInferenceResponse(seqNo, r),
69+
e -> bulkExecutionState.onInferenceException(seqNo, e)
70+
);
71+
throttledInferenceRunner.doInference(request, listener);
72+
}
73+
}
74+
75+
private void persistsInferenceResponses(
76+
BulkInferenceExecutionState bulkExecutionState,
77+
CheckedConsumer<InferenceAction.Response, Exception> persister
78+
) throws TimeoutException {
79+
// TODO: retry should be from config
80+
int retry = 30;
81+
while (bulkExecutionState.getPersistedCheckpoint() < bulkExecutionState.getMaxSeqNo()) {
82+
Long seqNo = bulkExecutionState.fetchProcessedSeqNo();
83+
retry--;
84+
85+
if (seqNo == null && retry < 0) {
86+
throw new TimeoutException("timeout waiting for inference response");
87+
}
88+
89+
long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint();
90+
91+
while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) {
92+
persistedSeqNo++;
93+
InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo);
94+
assert response != null || bulkExecutionState.hasFailure();
95+
if (bulkExecutionState.hasFailure() == false) {
96+
try {
97+
persister.accept(response);
98+
} catch (Exception e) {
99+
bulkExecutionState.addFailure(e);
100+
}
101+
}
102+
bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo);
103+
}
104+
}
66105
}
67106

68107
private static class ThrottledInferenceRunner extends ThrottledTaskRunner {
@@ -85,17 +124,16 @@ public static ThrottledInferenceRunner create(
85124
return new ThrottledInferenceRunner(inferenceRunner, threadPool, bulkExecutionConfig.workers());
86125
}
87126

88-
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener)
89-
throws InterruptedException, TimeoutException {
127+
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
90128
this.enqueueTask(listener.delegateFailureAndWrap((l, releasable) -> {
91129
try (releasable) {
92130
inferenceRunner.doInference(request, l);
93131
}
94132
}));
95133
}
96-
}
97134

98-
private static ExecutorService executorService(ThreadPool threadPool) {
99-
return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME);
135+
private static ExecutorService executorService(ThreadPool threadPool) {
136+
return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME);
137+
}
100138
}
101139
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder,
3030

3131
@Override
3232
public void close() {
33-
Releasables.closeExpectNoException(outputBlockBuilder);
33+
Releasables.close(outputBlockBuilder);
3434
releasePageOnAnyThread(inputPage);
3535
}
3636

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ public abstract class InferenceOperatorTestCase<InferenceResultsType extends Inf
6262

6363
@Before
6464
public void initChannels() {
65-
channelCount = between(2, 10);
65+
channelCount = between(2, 3);
6666
elementTypes = randomElementTypes(channelCount);
6767
}
6868

@@ -73,7 +73,7 @@ public void setThreadPool() {
7373
new FixedExecutorBuilder(
7474
Settings.EMPTY,
7575
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
76-
1,
76+
between(1, 10),
7777
1024,
7878
"esql",
7979
EsExecutors.TaskTrackingConfig.DEFAULT

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ public void setThreadPool() {
5555
new FixedExecutorBuilder(
5656
Settings.EMPTY,
5757
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
58-
1,
58+
between(1, 20),
5959
1024,
6060
"esql",
6161
EsExecutors.TaskTrackingConfig.DEFAULT

0 commit comments

Comments
 (0)