
Commit 9a26f61

Minor code improvement.
1 parent 3d0819d commit 9a26f61

7 files changed: 91 additions, 100 deletions


x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java

Lines changed: 2 additions & 3 deletions
@@ -26,10 +26,9 @@ public class InferenceRunner {
     private final Client client;
     private final ThreadPool threadPool;
 
-    public InferenceRunner(Client client) {
+    public InferenceRunner(Client client, ThreadPool threadPool) {
         this.client = client;
-        // TODO: revisit the executor service instantiation and thread pool choice.
-        this.threadPool = client.threadPool();
+        this.threadPool = threadPool;
     }
 
     public ThreadPool threadPool() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 import java.util.concurrent.TimeUnit;
 
 public record BulkInferenceExecutionConfig(TimeValue inferenceTimeout, int workers) {
-    public static final TimeValue DEFAULT_INFERENCE_TIMEOUT = new TimeValue(10, TimeUnit.SECONDS);
+    public static final TimeValue DEFAULT_INFERENCE_TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
     public static final int DEFAULT_WORKERS = 10;
 
     public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig(DEFAULT_INFERENCE_TIMEOUT, DEFAULT_WORKERS);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 3 additions & 22 deletions
@@ -7,16 +7,12 @@
 
 package org.elasticsearch.xpack.esql.inference.bulk;
 
-import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.compute.operator.FailureCollector;
 import org.elasticsearch.index.seqno.LocalCheckpointTracker;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 
 import java.util.Map;
-import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
@@ -25,7 +21,6 @@ public class BulkInferenceExecutionState {
     private final LocalCheckpointTracker checkpoint = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
     private final FailureCollector failureCollector = new FailureCollector();
     private final Map<Long, InferenceAction.Response> bufferedResponses = new ConcurrentHashMap<>();
-    private final BlockingQueue<Long> processedSeqNoQueue = ConcurrentCollections.newBlockingQueue();
     private final AtomicBoolean finished = new AtomicBoolean(false);
 
     public long generateSeqNo() {
@@ -45,39 +40,25 @@ public long getMaxSeqNo() {
     }
 
     public void onInferenceResponse(long seqNo, InferenceAction.Response response) {
+        assert seqNo <= getMaxSeqNo();
         if (failureCollector.hasFailure() == false) {
             bufferedResponses.put(seqNo, response);
         }
         checkpoint.markSeqNoAsProcessed(seqNo);
-        processedSeqNoQueue.offer(seqNo);
     }
 
     public void onInferenceException(long seqNo, Exception e) {
+        assert seqNo <= getMaxSeqNo();
         failureCollector.unwrapAndCollect(e);
         checkpoint.markSeqNoAsProcessed(seqNo);
-        processedSeqNoQueue.offer(seqNo);
-    }
-
-    public long fetchProcessedSeqNo(int retry) throws InterruptedException, TimeoutException {
-        while (retry > 0) {
-            if (finished()) {
-                return -1;
-            }
-            retry--;
-            Long seqNo = processedSeqNoQueue.poll(1, TimeUnit.SECONDS);
-            if (seqNo != null) {
-                return seqNo;
-            }
-        }
-
-        throw new TimeoutException("timeout waiting for inference response");
     }
 
     public InferenceAction.Response fetchBufferedResponse(long seqNo) {
         return bufferedResponses.remove(seqNo);
     }
 
     public void markSeqNoAsPersisted(long seqNo) {
+        assert seqNo <= getMaxSeqNo();
        checkpoint.markSeqNoAsPersisted(seqNo);
     }
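
Note on the change above: with processedSeqNoQueue and the blocking fetchProcessedSeqNo loop gone, ordering rests entirely on the LocalCheckpointTracker checkpoints: responses are buffered by sequence number as they are processed, and the consumer drains them in order by advancing the persisted checkpoint. Below is a minimal JDK-only sketch of that buffer-and-drain pattern; the class and its checkpoint handling are illustrative, not the Elasticsearch tracker (which only advances over contiguous sequence numbers).

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

// Illustrative sketch only: mimics "buffer by seqNo, drain in seqNo order up to the
// processed checkpoint" without Elasticsearch's LocalCheckpointTracker.
class OrderedResponseBuffer<T> {
    private final AtomicLong nextSeqNo = new AtomicLong();             // seqNo generator
    private final AtomicLong processedCheckpoint = new AtomicLong(-1); // simplified: assumes in-order processing
    private final Map<Long, T> buffered = new ConcurrentHashMap<>();
    private long persistedCheckpoint = -1;                             // guarded by the synchronized drain()

    long generateSeqNo() {
        return nextSeqNo.getAndIncrement();
    }

    // Called from any worker thread once the response for seqNo has been produced.
    void onResponse(long seqNo, T response) {
        buffered.put(seqNo, response);
        processedCheckpoint.accumulateAndGet(seqNo, Math::max);
    }

    // Drains buffered responses in seqNo order, up to the processed checkpoint.
    synchronized void drain(Consumer<T> consumer) {
        while (persistedCheckpoint < processedCheckpoint.get()) {
            persistedCheckpoint++;
            T response = buffered.remove(persistedCheckpoint);
            if (response != null) {
                consumer.accept(response);
            }
        }
    }
}

In the real classes this drain is what ResponseHandler.persistPendingResponses in BulkInferenceExecutor (next file) does under synchronized (bulkExecutionState), using getPersistedCheckpoint() and getProcessedCheckpoint().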

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 43 additions & 56 deletions
@@ -17,93 +17,80 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ExecutorService;
-import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 public class BulkInferenceExecutor {
     private static final String TASK_RUNNER_NAME = "bulk_inference_operation";
-    private static final int INFERENCE_RESPONSE_TIMEOUT = 30; // TODO: should be in the config.
     private final ThrottledInferenceRunner throttledInferenceRunner;
-    private final ExecutorService executorService;
 
     public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) {
-        executorService = executorService(threadPool);
-        throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService, bulkExecutionConfig);
+        this.throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig);
     }
 
     public void execute(BulkInferenceRequestIterator requests, ActionListener<InferenceAction.Response[]> listener) throws Exception {
-        final ResponseHandler responseHandler = new ResponseHandler();
-        runInferenceRequests(requests, listener.delegateFailureAndWrap(responseHandler::handleResponses));
-    }
+        if (requests.hasNext() == false) {
+            listener.onResponse(new InferenceAction.Response[0]);
+            return;
+        }
 
-    private void runInferenceRequests(BulkInferenceRequestIterator requests, ActionListener<BulkInferenceExecutionState> listener) {
         final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState();
-        try {
-            executorService.execute(() -> {
-                while (bulkExecutionState.finished() == false && requests.hasNext()) {
-                    InferenceAction.Request request = requests.next();
-                    long seqNo = bulkExecutionState.generateSeqNo();
-                    throttledInferenceRunner.doInference(
-                        request,
-                        ActionListener.wrap(
-                            r -> bulkExecutionState.onInferenceResponse(seqNo, r),
-                            e -> bulkExecutionState.onInferenceException(seqNo, e)
-                        )
-                    );
-                }
-                bulkExecutionState.finish();
-            });
-        } catch (RejectedExecutionException e) {
-            bulkExecutionState.addFailure(new IllegalStateException("Unable to enqueue inference requests", e));
-            bulkExecutionState.finish();
-        } finally {
-            listener.onResponse(bulkExecutionState);
+        final ResponseHandler responseHandler = new ResponseHandler(bulkExecutionState, listener);
+
+        while (bulkExecutionState.finished() == false && requests.hasNext()) {
+            InferenceAction.Request request = requests.next();
+            long seqNo = bulkExecutionState.generateSeqNo();
+            throttledInferenceRunner.doInference(request, responseHandler.inferenceResponseListener(seqNo));
         }
+        bulkExecutionState.finish();
     }
 
     private static class ResponseHandler {
         private final List<InferenceAction.Response> responses = new ArrayList<>();
+        private final BulkInferenceExecutionState bulkExecutionState;
+        private final ActionListener<InferenceAction.Response[]> completionListener;
+        private final AtomicBoolean responseSent = new AtomicBoolean(false);
 
-        private void handleResponses(ActionListener<InferenceAction.Response[]> listener, BulkInferenceExecutionState bulkExecutionState) {
-
-            try {
-                persistsInferenceResponses(bulkExecutionState);
-            } catch (InterruptedException | TimeoutException e) {
-                bulkExecutionState.addFailure(e);
-                bulkExecutionState.finish();
-            }
-
-            if (bulkExecutionState.hasFailure() == false) {
-                try {
-                    listener.onResponse(responses.toArray(InferenceAction.Response[]::new));
-                    return;
-                } catch (Exception e) {
-                    bulkExecutionState.addFailure(e);
-                }
-            }
+        private ResponseHandler(BulkInferenceExecutionState bulkExecutionState, ActionListener<InferenceAction.Response[]> completionListener) {
+            this.bulkExecutionState = bulkExecutionState;
+            this.completionListener = completionListener;
+        }
 
-            listener.onFailure(bulkExecutionState.getFailure());
+        ActionListener<InferenceAction.Response> inferenceResponseListener(long seqNo) {
+            return ActionListener.runAfter(ActionListener.wrap(
+                r -> bulkExecutionState.onInferenceResponse(seqNo, r),
+                e -> bulkExecutionState.onInferenceException(seqNo, e)
+            ), this::persistPendingResponses);
         }
 
-        private void persistsInferenceResponses(BulkInferenceExecutionState bulkExecutionState) throws TimeoutException,
-            InterruptedException {
-            while (bulkExecutionState.finished() == false && bulkExecutionState.fetchProcessedSeqNo(INFERENCE_RESPONSE_TIMEOUT) >= 0) {
+        private void persistPendingResponses() {
+            synchronized (bulkExecutionState) {
                 long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint();
 
                 while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) {
                     persistedSeqNo++;
                     InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo);
                     assert response != null || bulkExecutionState.hasFailure();
+
                     if (bulkExecutionState.hasFailure() == false) {
-                        try {
-                            responses.add(response);
-                        } catch (Exception e) {
-                            bulkExecutionState.addFailure(e);
-                        }
+                        responses.add(response);
                     }
+
                     bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo);
                 }
             }
+
+            sendResponseOnCompletion();
+        }
+
+        private void sendResponseOnCompletion() {
+            if (bulkExecutionState.finished() && responseSent.compareAndSet(false, true)) {
+                if (bulkExecutionState.hasFailure() == false) {
+                    completionListener.onResponse(responses.toArray(InferenceAction.Response[]::new));
+                    return;
+                }
+
+                completionListener.onFailure(bulkExecutionState.getFailure());
+            }
         }
     }
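
Note on the change above: each per-request listener is composed with ActionListener.runAfter, so persistPendingResponses runs after every outcome, success or failure, and sendResponseOnCompletion relies on responseSent.compareAndSet(false, true) to fire the completion listener exactly once even when callbacks race across threads. A small JDK-only sketch of that exactly-once completion guard follows; the names are illustrative and the gate here is a simple count rather than the executor's finished()/checkpoint state.

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative sketch only: outcomes arrive on arbitrary threads and in any order,
// but the compareAndSet guard lets the completion callback run exactly once.
class OnceOnlyCompletion {
    private final int expected;
    private final AtomicInteger seen = new AtomicInteger();
    private final AtomicBoolean sent = new AtomicBoolean(false);
    private final Runnable onComplete;

    OnceOnlyCompletion(int expected, Runnable onComplete) {
        this.expected = expected;
        this.onComplete = onComplete;
    }

    // Called once per request outcome, success or failure, from any thread.
    void onOutcome() {
        if (seen.incrementAndGet() >= expected && sent.compareAndSet(false, true)) {
            onComplete.run();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch done = new CountDownLatch(1);
        OnceOnlyCompletion completion = new OnceOnlyCompletion(3, done::countDown);
        for (int i = 0; i < 3; i++) {
            new Thread(completion::onOutcome).start(); // racing callbacks; still completes once
        }
        done.await();
        System.out.println("completed exactly once");
    }
}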

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ public TransportEsqlQueryAction(
             projectResolver,
             indexNameExpressionResolver,
             usageService,
-            new InferenceRunner(client)
+            new InferenceRunner(client, threadPool)
         );
 
         this.computeService = new ComputeService(

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java

Lines changed: 3 additions & 4 deletions
@@ -65,7 +65,7 @@ public void shutdownThreadPool() {
     }
 
     public void testResolveInferenceIds() throws Exception {
-        InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
+        InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool);
         List<InferencePlan<?>> inferencePlans = List.of(mockInferencePlan("rerank-plan"));
         SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
 
@@ -82,7 +82,7 @@ public void testResolveInferenceIds() throws Exception {
     }
 
     public void testResolveMultipleInferenceIds() throws Exception {
-        InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
+        InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool);
         List<InferencePlan<?>> inferencePlans = List.of(
             mockInferencePlan("rerank-plan"),
             mockInferencePlan("rerank-plan"),
@@ -110,7 +110,7 @@ public void testResolveMultipleInferenceIds() throws Exception {
     }
 
     public void testResolveMissingInferenceIds() throws Exception {
-        InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
+        InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool);
         List<InferencePlan<?>> inferencePlans = List.of(mockInferencePlan("missing-plan"));
 
         SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
@@ -132,7 +132,6 @@ public void testResolveMissingInferenceIds() throws Exception {
     @SuppressWarnings({ "unchecked", "raw-types" })
     private Client mockClient() {
         Client client = mock(Client.class);
-        when(client.threadPool()).thenReturn(threadPool);
         doAnswer(i -> {
             Runnable sendResponse = () -> {
                 GetInferenceModelAction.Request request = i.getArgument(1, GetInferenceModelAction.Request.class);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java

Lines changed: 38 additions & 13 deletions
@@ -25,6 +25,7 @@
 
 import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Stream;
@@ -67,8 +68,11 @@ public void testSuccessfulExecution() throws Exception {
         InferenceAction.Response[] responses = randomInferenceResponseList(requests.size()).toArray(InferenceAction.Response[]::new);
 
         InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
-            ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
-            l.onResponse(responses[requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))]);
+            runWithRandomDelay(() -> {
+                ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
+                l.onResponse(responses[requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))]);
+            });
+
             return null;
         });
 
@@ -96,39 +100,50 @@ public void testInferenceRunnerAlwaysFails() throws Exception {
         List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 30));
 
         InferenceRunner inferenceRunner = mock(invocation -> {
-            ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
-            listener.onFailure(new RuntimeException("inference failure"));
+            runWithRandomDelay(() -> {
+                ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
+                listener.onFailure(new RuntimeException("inference failure"));
+            });
+
             return null;
         });
 
         AtomicReference<Exception> exception = new AtomicReference<>();
-        ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set);
+        ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expected exception"), exception::set);
 
         bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener);
 
-        assertBusy(() -> assertThat(exception.get().getMessage(), equalTo("inference failure")));
+        assertBusy(() -> {
+            assertThat(exception.get(), notNullValue());
+            assertThat(exception.get().getMessage(), equalTo("inference failure"));
+        });
     }
 
     public void testInferenceRunnerSometimesFails() throws Exception {
         List<InferenceAction.Request> requests = randomInferenceRequestList(between(2, 30));
 
         InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
             ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
-            if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) {
-                listener.onFailure(new RuntimeException("inference failure"));
-            } else {
-                listener.onResponse(mockInferenceResponse());
-            }
+            runWithRandomDelay(() -> {
+                if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) {
+                    listener.onFailure(new RuntimeException("inference failure"));
+                } else {
+                    listener.onResponse(mockInferenceResponse());
+                }
+            });
 
             return null;
         });
 
         AtomicReference<Exception> exception = new AtomicReference<>();
-        ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set);
+        ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expected exception"), exception::set);
 
         bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener);
 
-        assertBusy(() -> assertThat(exception.get().getMessage(), equalTo("inference failure")));
+        assertBusy(() -> {
+            assertThat(exception.get(), notNullValue());
+            assertThat(exception.get().getMessage(), equalTo("inference failure"));
+        });
     }
 
     private BulkInferenceExecutor bulkExecutor(InferenceRunner inferenceRunner) {
@@ -170,4 +185,14 @@ private InferenceRunner mockInferenceRunner(Answer<Void> doInferenceAnswer) {
         doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any());
         return inferenceRunner;
     }
+
+    private void runWithRandomDelay(Runnable runnable) {
+        if (randomBoolean()) {
+            runnable.run();
+        } else {
+            ExecutorService executor = threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME);
+            TimeValue delay = TimeValue.timeValueNanos(between(1, 1_000));
+            threadPool.schedule(runnable, delay, executor);
+        }
+    }
 }
