Skip to content

Commit 340c189

Browse files
committed
Update throttling mechanism
1 parent 1e95722 commit 340c189

File tree

12 files changed

+264
-219
lines changed

12 files changed

+264
-219
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,8 @@
2525
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
2626

2727
public abstract class InferenceOperator extends AsyncOperator<InferenceOperator.OngoingInference> {
28-
29-
// Move to a setting.
30-
private static final int MAX_INFERENCE_WORKER = 10;
3128
private final String inferenceId;
3229
private final BlockFactory blockFactory;
33-
private final BulkInferenceExecutionConfig bulkExecutionConfig;
3430
private final BulkInferenceExecutor bulkInferenceExecutor;
3531

3632
public InferenceOperator(
@@ -40,9 +36,8 @@ public InferenceOperator(
4036
ThreadPool threadPool,
4137
String inferenceId
4238
) {
43-
super(driverContext, threadPool.getThreadContext(), MAX_INFERENCE_WORKER);
39+
super(driverContext, threadPool.getThreadContext(), bulkExecutionConfig.workers());
4440
this.blockFactory = driverContext.blockFactory();
45-
this.bulkExecutionConfig = bulkExecutionConfig;
4641
this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig);
4742
this.inferenceId = inferenceId;
4843
}
@@ -60,6 +55,17 @@ protected void releaseFetchedOnAnyThread(OngoingInference result) {
6055
releasePageOnAnyThread(result.inputPage);
6156
}
6257

58+
@Override
59+
protected void performAsync(Page input, ActionListener<OngoingInference> listener) {
60+
try {
61+
BulkInferenceRequestIterator requests = requests(input);
62+
listener = ActionListener.releaseAfter(listener, requests);
63+
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInference(input, responses)));
64+
} catch (Exception e) {
65+
listener.onFailure(e);
66+
}
67+
}
68+
6369
@Override
6470
public Page getOutput() {
6571
OngoingInference ongoingInference = fetchFromBuffer();
@@ -75,19 +81,6 @@ public Page getOutput() {
7581
}
7682
}
7783

78-
@Override
79-
protected void performAsync(Page input, ActionListener<OngoingInference> listener) {
80-
try {
81-
bulkInferenceExecutor.execute(requests(input), listener.map(responses -> new OngoingInference(input, responses)));
82-
} catch (Exception e) {
83-
listener.onFailure(e);
84-
}
85-
}
86-
87-
protected BulkInferenceExecutionConfig bulkExecutionConfig() {
88-
return bulkExecutionConfig;
89-
}
90-
9184
protected abstract BulkInferenceRequestIterator requests(Page input);
9285

9386
protected abstract OutputBuilder outputBuilder(Page input);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,9 @@
77

88
package org.elasticsearch.xpack.esql.inference.bulk;
99

10-
import org.elasticsearch.core.TimeValue;
10+
public record BulkInferenceExecutionConfig(int workers, int maxOutstandingRequests) {
11+
public static final int DEFAULT_WORKERS = 2;
12+
public static final int DEFAULT_MAX_OUTSTANDING_REQUESTS = 50;
1113

12-
import java.util.concurrent.TimeUnit;
13-
14-
public record BulkInferenceExecutionConfig(TimeValue inferenceTimeout, int workers) {
15-
public static final TimeValue DEFAULT_INFERENCE_TIMEOUT = new TimeValue(10, TimeUnit.SECONDS);
16-
public static final int DEFAULT_WORKERS = 10;
17-
18-
public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig(DEFAULT_INFERENCE_TIMEOUT, DEFAULT_WORKERS);
14+
public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig(DEFAULT_WORKERS, DEFAULT_MAX_OUTSTANDING_REQUESTS);
1915
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,12 @@
77

88
package org.elasticsearch.xpack.esql.inference.bulk;
99

10-
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
1110
import org.elasticsearch.compute.operator.FailureCollector;
1211
import org.elasticsearch.index.seqno.LocalCheckpointTracker;
1312
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1413

1514
import java.util.Map;
16-
import java.util.concurrent.BlockingQueue;
1715
import java.util.concurrent.ConcurrentHashMap;
18-
import java.util.concurrent.TimeUnit;
19-
import java.util.concurrent.TimeoutException;
2016
import java.util.concurrent.atomic.AtomicBoolean;
2117

2218
import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
@@ -25,7 +21,6 @@ public class BulkInferenceExecutionState {
2521
private final LocalCheckpointTracker checkpoint = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
2622
private final FailureCollector failureCollector = new FailureCollector();
2723
private final Map<Long, InferenceAction.Response> bufferedResponses = new ConcurrentHashMap<>();
28-
private final BlockingQueue<Long> processedSeqNoQueue = ConcurrentCollections.newBlockingQueue();
2924
private final AtomicBoolean finished = new AtomicBoolean(false);
3025

3126
public long generateSeqNo() {
@@ -49,28 +44,11 @@ public void onInferenceResponse(long seqNo, InferenceAction.Response response) {
4944
bufferedResponses.put(seqNo, response);
5045
}
5146
checkpoint.markSeqNoAsProcessed(seqNo);
52-
processedSeqNoQueue.offer(seqNo);
5347
}
5448

5549
public void onInferenceException(long seqNo, Exception e) {
5650
failureCollector.unwrapAndCollect(e);
5751
checkpoint.markSeqNoAsProcessed(seqNo);
58-
processedSeqNoQueue.offer(seqNo);
59-
}
60-
61-
public long fetchProcessedSeqNo(int retry) throws InterruptedException, TimeoutException {
62-
while (retry > 0) {
63-
if (finished()) {
64-
return -1;
65-
}
66-
retry--;
67-
Long seqNo = processedSeqNoQueue.poll(1, TimeUnit.SECONDS);
68-
if (seqNo != null) {
69-
return seqNo;
70-
}
71-
}
72-
73-
throw new TimeoutException("timeout waiting for inference response");
7452
}
7553

7654
public InferenceAction.Response fetchBufferedResponse(long seqNo) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 123 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -8,110 +8,118 @@
88
package org.elasticsearch.xpack.esql.inference.bulk;
99

1010
import org.elasticsearch.action.ActionListener;
11-
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
11+
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
12+
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
1213
import org.elasticsearch.threadpool.ThreadPool;
1314
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1415
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1516
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
1617

1718
import java.util.ArrayList;
1819
import java.util.List;
20+
import java.util.concurrent.BlockingQueue;
1921
import java.util.concurrent.ExecutorService;
20-
import java.util.concurrent.RejectedExecutionException;
21-
import java.util.concurrent.TimeoutException;
22+
import java.util.concurrent.Semaphore;
23+
import java.util.concurrent.atomic.AtomicBoolean;
2224

2325
public class BulkInferenceExecutor {
24-
private static final String TASK_RUNNER_NAME = "bulk_inference_operation";
25-
private static final int INFERENCE_RESPONSE_TIMEOUT = 30; // TODO: should be in the config.
2626
private final ThrottledInferenceRunner throttledInferenceRunner;
27-
private final ExecutorService executorService;
2827

2928
public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) {
30-
executorService = executorService(threadPool);
31-
throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService, bulkExecutionConfig);
29+
throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig);
3230
}
3331

3432
public void execute(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) throws Exception {
35-
final ResponseHandler responseHandler = new ResponseHandler();
36-
runInferenceRequests(requests, listener.delegateFailureAndWrap(responseHandler::handleResponses));
37-
}
33+
if (requests.hasNext() == false) {
34+
listener.onResponse(List.of());
35+
return;
36+
}
3837

39-
private void runInferenceRequests(BulkInferenceRequestIterator requests, ActionListener<BulkInferenceExecutionState> listener) {
4038
final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState();
41-
try {
42-
executorService.execute(() -> {
43-
while (bulkExecutionState.finished() == false && requests.hasNext()) {
44-
InferenceAction.Request request = requests.next();
45-
long seqNo = bulkExecutionState.generateSeqNo();
46-
throttledInferenceRunner.doInference(
47-
request,
48-
ActionListener.wrap(
49-
r -> bulkExecutionState.onInferenceResponse(seqNo, r),
50-
e -> bulkExecutionState.onInferenceException(seqNo, e)
51-
)
52-
);
53-
}
39+
final ResponseHandler responseHandler = new ResponseHandler(bulkExecutionState, listener, requests.estimatedSize());
40+
41+
while (bulkExecutionState.finished() == false && requests.hasNext()) {
42+
InferenceAction.Request request = requests.next();
43+
long seqNo = bulkExecutionState.generateSeqNo();
44+
45+
if (requests.hasNext() == false) {
5446
bulkExecutionState.finish();
55-
});
56-
} catch (RejectedExecutionException e) {
57-
bulkExecutionState.addFailure(new IllegalStateException("Unable to enqueue inference requests", e));
58-
bulkExecutionState.finish();
59-
} finally {
60-
listener.onResponse(bulkExecutionState);
47+
}
48+
49+
throttledInferenceRunner.doInference(
50+
request,
51+
ActionListener.runAfter(
52+
ActionListener.wrap(
53+
r -> bulkExecutionState.onInferenceResponse(seqNo, r),
54+
e -> bulkExecutionState.onInferenceException(seqNo, e)
55+
),
56+
responseHandler::persistsInferenceResponses
57+
)
58+
);
6159
}
6260
}
6361

6462
private static class ResponseHandler {
65-
private final List<InferenceAction.Response> responses = new ArrayList<>();
66-
67-
private void handleResponses(ActionListener<List<InferenceAction.Response>> listener, BulkInferenceExecutionState bulkExecutionState) {
68-
69-
try {
70-
persistsInferenceResponses(bulkExecutionState);
71-
} catch (InterruptedException | TimeoutException e) {
72-
bulkExecutionState.addFailure(e);
73-
bulkExecutionState.finish();
74-
}
63+
private final List<InferenceAction.Response> responses;
64+
private final ActionListener<List<InferenceAction.Response>> listener;
65+
private final BulkInferenceExecutionState bulkExecutionState;
66+
private final AtomicBoolean responseSent = new AtomicBoolean(false);
67+
68+
private ResponseHandler(
69+
BulkInferenceExecutionState bulkExecutionState,
70+
ActionListener<List<InferenceAction.Response>> listener,
71+
int estimatedSize
72+
) {
73+
this.listener = listener;
74+
this.bulkExecutionState = bulkExecutionState;
75+
this.responses = new ArrayList<>(estimatedSize);
76+
}
7577

76-
if (bulkExecutionState.hasFailure() == false) {
77-
try {
78-
listener.onResponse(responses);
79-
return;
80-
} catch (Exception e) {
81-
bulkExecutionState.addFailure(e);
78+
public synchronized void persistsInferenceResponses() {
79+
long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint();
80+
81+
while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) {
82+
persistedSeqNo++;
83+
InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo);
84+
assert response != null || bulkExecutionState.hasFailure();
85+
if (bulkExecutionState.hasFailure() == false) {
86+
try {
87+
responses.add(response);
88+
} catch (Exception e) {
89+
bulkExecutionState.addFailure(e);
90+
}
8291
}
92+
bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo);
8393
}
8494

85-
listener.onFailure(bulkExecutionState.getFailure());
95+
sendResponseOnCompletion();
8696
}
8797

88-
private void persistsInferenceResponses(BulkInferenceExecutionState bulkExecutionState) throws TimeoutException,
89-
InterruptedException {
90-
while (bulkExecutionState.finished() == false && bulkExecutionState.fetchProcessedSeqNo(INFERENCE_RESPONSE_TIMEOUT) >= 0) {
91-
long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint();
92-
93-
while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) {
94-
persistedSeqNo++;
95-
InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo);
96-
assert response != null || bulkExecutionState.hasFailure();
97-
if (bulkExecutionState.hasFailure() == false) {
98-
try {
99-
responses.add(response);
100-
} catch (Exception e) {
101-
bulkExecutionState.addFailure(e);
102-
}
98+
private void sendResponseOnCompletion() {
99+
if (bulkExecutionState.finished() && responseSent.compareAndSet(false, true)) {
100+
if (bulkExecutionState.hasFailure() == false) {
101+
try {
102+
listener.onResponse(responses);
103+
return;
104+
} catch (Exception e) {
105+
bulkExecutionState.addFailure(e);
103106
}
104-
bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo);
105107
}
108+
109+
listener.onFailure(bulkExecutionState.getFailure());
106110
}
107111
}
108112
}
109113

110-
private static class ThrottledInferenceRunner extends ThrottledTaskRunner {
114+
private static class ThrottledInferenceRunner {
111115
private final InferenceRunner inferenceRunner;
116+
private final ExecutorService executorService;
117+
private final BlockingQueue<AbstractRunnable> pendingRequests = ConcurrentCollections.newBlockingQueue();
118+
private final Semaphore permits;
112119

113120
private ThrottledInferenceRunner(InferenceRunner inferenceRunner, ExecutorService executorService, int maxRunningTasks) {
114-
super(TASK_RUNNER_NAME, maxRunningTasks, executorService);
121+
this.executorService = executorService;
122+
this.permits = new Semaphore(maxRunningTasks);
115123
this.inferenceRunner = inferenceRunner;
116124
}
117125

@@ -120,13 +128,58 @@ public static ThrottledInferenceRunner create(
120128
ExecutorService executorService,
121129
BulkInferenceExecutionConfig bulkExecutionConfig
122130
) {
123-
return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.workers());
131+
return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.maxOutstandingRequests());
124132
}
125133

126134
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
127-
this.enqueueTask(listener.delegateFailureAndWrap((l, releasable) -> {
128-
inferenceRunner.doInference(request, ActionListener.releaseAfter(l, releasable));
129-
}));
135+
enqueueTask(request, listener);
136+
executePendingRequests();
137+
}
138+
139+
private void executePendingRequests() {
140+
while (permits.tryAcquire()) {
141+
AbstractRunnable task = pendingRequests.poll();
142+
143+
if (task == null) {
144+
permits.release();
145+
return;
146+
}
147+
148+
try {
149+
executorService.execute(task);
150+
} catch (Exception e){
151+
task.onFailure(e);
152+
permits.release();
153+
}
154+
}
155+
}
156+
157+
private void enqueueTask(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
158+
try {
159+
pendingRequests.add(createTask(request, listener));
160+
executePendingRequests();
161+
} catch (Exception e) {
162+
listener.onFailure(new IllegalStateException("An error occurred while adding the inference request to the queue", e));
163+
}
164+
}
165+
166+
private AbstractRunnable createTask(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
167+
final ActionListener<InferenceAction.Response> completionListener = ActionListener.runAfter(listener, () -> {
168+
permits.release();
169+
executePendingRequests();
170+
});
171+
172+
return new AbstractRunnable() {
173+
@Override
174+
protected void doRun() {
175+
inferenceRunner.doInference(request, completionListener);
176+
}
177+
178+
@Override
179+
public void onFailure(Exception e) {
180+
completionListener.onFailure(e);
181+
}
182+
};
130183
}
131184
}
132185

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
package org.elasticsearch.xpack.esql.inference.bulk;
99

10+
import org.elasticsearch.core.Releasable;
1011
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1112

1213
import java.util.Iterator;
1314

14-
public interface BulkInferenceRequestIterator extends Iterator<InferenceAction.Request> {
15+
public interface BulkInferenceRequestIterator extends Iterator<InferenceAction.Request>, Releasable {
16+
int estimatedSize();
1517

1618
}

0 commit comments

Comments
 (0)