Skip to content

Commit 7cb2eae

Browse files
committed
Performance improvements.
1 parent 6fa285a commit 7cb2eae

File tree

6 files changed

+144
-88
lines changed

6 files changed

+144
-88
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ protected void performAsync(Page input, ActionListener<OngoingInferenceResult> l
8080
BulkInferenceRequestItemIterator requests = requests(input);
8181
listener = ActionListener.releaseBefore(requests, listener);
8282

83-
OngoingInferenceResult result = new OngoingInferenceResult(input, new ArrayList<>());
83+
// ✅ Pre-size based on estimated request count
84+
int estimatedSize = requests.estimatedSize();
85+
int initialCapacity = Math.max(10, Math.min(estimatedSize, 10000)); // Cap at 10k for safety
86+
OngoingInferenceResult result = new OngoingInferenceResult(input, new ArrayList<>(initialCapacity));
8487
listener = listener.delegateResponse((l, e) -> {
8588
Releasables.close(result);
8689
l.onFailure(e);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import org.elasticsearch.compute.operator.FailureCollector;
1111
import org.elasticsearch.index.seqno.LocalCheckpointTracker;
1212

13+
import java.util.HashMap;
1314
import java.util.Map;
14-
import java.util.concurrent.ConcurrentHashMap;
1515
import java.util.concurrent.atomic.AtomicBoolean;
1616

1717
import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
@@ -26,8 +26,25 @@ public class BulkInferenceExecutionState {
2626
private final Map<Long, BulkInferenceResponse> bufferedResponses;
2727
private final AtomicBoolean finished = new AtomicBoolean(false);
2828

29+
/**
30+
* Creates a new execution state with default buffer capacity.
31+
*/
2932
public BulkInferenceExecutionState() {
30-
this.bufferedResponses = new ConcurrentHashMap<>();
33+
this(16);
34+
}
35+
36+
/**
37+
* Creates a new execution state with the specified initial buffer capacity.
38+
* <p>
39+
* The initial capacity should be sized based on the expected number of out-of-order responses.
40+
* A good heuristic is to use a fraction of maxRunningTasks, as that bounds the number of
41+
* concurrent in-flight responses that could arrive out-of-order.
42+
* </p>
43+
*
44+
* @param initialCapacity The initial capacity for the response buffer
45+
*/
46+
public BulkInferenceExecutionState(int initialCapacity) {
47+
this.bufferedResponses = new HashMap<>(initialCapacity);
3148
}
3249

3350
/**
@@ -68,9 +85,9 @@ public void markSeqNoAsPersisted(long seqNo) {
6885
}
6986

7087
/**
71-
* Add an inference response to the buffer and marks the corresponding sequence number as processed.
88+
* Buffers an inference response and marks the corresponding sequence number as processed.
7289
*
73-
* @param response The bulk inference response object
90+
* @param response The bulk inference response object
7491
*/
7592
public synchronized void onInferenceResponse(BulkInferenceResponse response) {
7693
if (response != null && failureCollector.hasFailure() == false) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java

Lines changed: 101 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.support.ThreadedActionListener;
1212
import org.elasticsearch.client.internal.Client;
13-
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
1413
import org.elasticsearch.threadpool.ThreadPool;
1514
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1615

1716
import java.util.Queue;
1817
import java.util.Set;
18+
import java.util.concurrent.ConcurrentHashMap;
1919
import java.util.concurrent.ConcurrentLinkedQueue;
2020
import java.util.concurrent.ExecutorService;
2121
import java.util.concurrent.Semaphore;
@@ -42,40 +42,27 @@ public class BulkInferenceRunner {
4242

4343
private final Client client;
4444
private final Semaphore permits;
45+
private final int maxRunningTasks;
4546
private final ExecutorService executor;
4647

4748
/**
48-
* Custom concurrent queue that prevents duplicate bulk requests from being queued.
49+
* Tracks bulk requests that are currently queued to prevent duplicates.
4950
* <p>
50-
* This queue implementation ensures fairness among multiple concurrent bulk operations
51-
* by preventing the same bulk request from being queued multiple times. It uses a
52-
* backing concurrent set to track which requests are already queued.
51+
* This set ensures fairness among multiple concurrent bulk operations by preventing
52+
* the same bulk request from being queued multiple times. Uses ConcurrentHashMap.newKeySet()
53+
* for lock-free thread-safe operations.
5354
* </p>
5455
*/
55-
private final Queue<BulkInferenceRequest> pendingBulkRequests = new ConcurrentLinkedQueue<>() {
56-
private final Set<BulkInferenceRequest> requests = ConcurrentCollections.newConcurrentSet();
56+
private final Set<BulkInferenceRequest> trackedRequests = ConcurrentHashMap.newKeySet();
5757

58-
@Override
59-
public boolean offer(BulkInferenceRequest bulkInferenceRequest) {
60-
synchronized (requests) {
61-
if (requests.add(bulkInferenceRequest)) {
62-
return super.offer(bulkInferenceRequest);
63-
}
64-
return false; // Already exists, don't add duplicate
65-
}
66-
}
67-
68-
@Override
69-
public BulkInferenceRequest poll() {
70-
synchronized (requests) {
71-
BulkInferenceRequest request = super.poll();
72-
if (request != null) {
73-
requests.remove(request);
74-
}
75-
return request;
76-
}
77-
}
78-
};
58+
/**
59+
* Queue of pending bulk requests waiting for permit availability.
60+
* <p>
61+
* Works in conjunction with {@link #trackedRequests} to ensure no duplicate requests
62+
* are queued while maintaining lock-free concurrent access.
63+
* </p>
64+
*/
65+
private final Queue<BulkInferenceRequest> pendingBulkRequests = new ConcurrentLinkedQueue<>();
7966

8067
/**
8168
* Constructs a new throttled inference runner with the specified configuration.
@@ -85,6 +72,7 @@ public BulkInferenceRequest poll() {
8572
*/
8673
public BulkInferenceRunner(Client client, int maxRunningTasks) {
8774
this.permits = new Semaphore(maxRunningTasks);
75+
this.maxRunningTasks = maxRunningTasks;
8876
this.client = client;
8977
this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
9078
}
@@ -142,7 +130,7 @@ private class BulkInferenceRequest {
142130
private final Consumer<BulkInferenceResponse> responseConsumer;
143131
private final ActionListener<Void> completionListener;
144132

145-
private final BulkInferenceExecutionState executionState = new BulkInferenceExecutionState();
133+
private final BulkInferenceExecutionState executionState;
146134
private final AtomicBoolean responseSent = new AtomicBoolean(false);
147135

148136
BulkInferenceRequest(
@@ -153,6 +141,15 @@ private class BulkInferenceRequest {
153141
this.requests = requests;
154142
this.responseConsumer = responseConsumer;
155143
this.completionListener = completionListener;
144+
145+
// Initialize buffer capacity based on expected out-of-order responses.
146+
// Use the minimum of:
147+
// 1. Half of maxRunningTasks (typical out-of-order buffer size with good network conditions)
148+
// 2. Estimated request size (if smaller, cap at that)
149+
// This balances memory efficiency with avoiding rehashing for typical workloads.
150+
int estimatedSize = requests.estimatedSize();
151+
int bufferCapacity = Math.max(1, Math.min(estimatedSize, maxRunningTasks) / 2);
152+
this.executionState = new BulkInferenceExecutionState(bufferCapacity);
156153
}
157154

158155
/**
@@ -180,7 +177,7 @@ private BulkInferenceRequestItem pollPendingRequest() {
180177
* This method implements a continuation-based asynchronous pattern with the following features:
181178
* - Queue-based fairness: Multiple bulk requests can be queued and processed fairly
182179
* - Permit-based concurrency control: Limits concurrent inference requests using semaphores
183-
* - Hybrid recursion strategy: Uses direct recursion for performance up to 100 levels,
180+
* - Hybrid recursion strategy: Uses direct recursion for performance up to 500 levels,
184181
* then switches to executor-based continuation to prevent stack overflow
185182
* - Duplicate prevention: Custom queue prevents the same bulk request from being queued multiple times
186183
* </p>
@@ -191,7 +188,7 @@ private BulkInferenceRequestItem pollPendingRequest() {
191188
* 3. Polls for the next available request from the iterator
192189
* 4. If no requests available, schedules the next queued bulk request
193190
* 5. Executes the request asynchronously with proper continuation handling
194-
* 6. Uses hybrid recursion: direct calls up to 100 levels, executor-based beyond that
191+
* 6. Uses hybrid recursion: direct calls up to 500 levels, executor-based beyond that
195192
* </p>
196193
* <p>
197194
* The loop terminates when:
@@ -209,7 +206,10 @@ private void executePendingRequests(int recursionDepth) {
209206
while (executionState.finished() == false) {
210207
if (permits.tryAcquire() == false) {
211208
if (requests.hasNext()) {
212-
pendingBulkRequests.add(this);
209+
// Add to tracking set first to prevent duplicates
210+
if (trackedRequests.add(this)) {
211+
pendingBulkRequests.offer(this);
212+
}
213213
}
214214
return;
215215
} else {
@@ -228,6 +228,10 @@ private void executePendingRequests(int recursionDepth) {
228228
}
229229

230230
if (nexBulkRequest != null) {
231+
// Remove from tracking set since we're about to process it
232+
trackedRequests.remove(nexBulkRequest);
233+
// Execute the next bulk request with reset recursion depth
234+
// Use final variable for lambda capture
231235
executor.execute(nexBulkRequest::executePendingRequests);
232236
}
233237

@@ -242,49 +246,54 @@ private void executePendingRequests(int recursionDepth) {
242246

243247
final ActionListener<InferenceAction.Response> inferenceResponseListener = new ThreadedActionListener<>(
244248
executor,
245-
ActionListener.runAfter(
246-
ActionListener.wrap(
247-
r -> executionState.onInferenceResponse(new BulkInferenceResponse(bulkInferenceRequestItem, r)),
248-
e -> executionState.onInferenceException(bulkInferenceRequestItem.seqNo(), e)
249-
),
250-
() -> {
251-
// Release the permit we used
252-
permits.release();
253-
254-
try {
255-
synchronized (executionState) {
256-
persistPendingResponses();
257-
}
249+
ActionListener.runAfter(ActionListener.wrap(r -> {
250+
BulkInferenceResponse bulkResponse = new BulkInferenceResponse(bulkInferenceRequestItem, r);
251+
executionState.onInferenceResponse(bulkResponse);
252+
}, e -> executionState.onInferenceException(bulkInferenceRequestItem.seqNo(), e)), () -> {
253+
// Release the permit we used
254+
permits.release();
255+
256+
try {
257+
synchronized (executionState) {
258+
persistPendingResponses();
259+
}
258260

259-
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
260-
onBulkCompletion();
261-
}
261+
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
262+
onBulkCompletion();
263+
}
262264

263-
if (responseSent.get()) {
264-
// Response has already been sent
265-
// No need to continue processing this bulk.
266-
// Check if another bulk request is pending for execution.
267-
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
268-
if (nexBulkRequest != null) {
269-
executor.execute(nexBulkRequest::executePendingRequests);
270-
}
271-
return;
265+
if (responseSent.get()) {
266+
// Response has already been sent
267+
// No need to continue processing this bulk.
268+
// Check if another bulk request is pending for execution.
269+
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
270+
if (nexBulkRequest != null) {
271+
// Remove from tracking set since we're about to process it
272+
trackedRequests.remove(nexBulkRequest);
273+
// Execute the next bulk request with reset recursion depth
274+
// Use final variable for lambda capture
275+
executor.execute(nexBulkRequest::executePendingRequests);
272276
}
273-
if (executionState.finished() == false) {
274-
// Execute any pending requests if any
275-
if (recursionDepth > 100) {
276-
executor.execute(this::executePendingRequests);
277-
} else {
278-
this.executePendingRequests(recursionDepth + 1);
279-
}
280-
}
281-
} catch (Exception e) {
282-
if (responseSent.compareAndSet(false, true)) {
283-
completionListener.onFailure(e);
277+
return;
278+
}
279+
if (executionState.finished() == false) {
280+
// Execute any pending requests if any
281+
if (recursionDepth > 500) {
282+
// Reset recursion depth by submitting to executor
283+
// This prevents unbounded stack growth while maintaining performance
284+
executor.execute(this::executePendingRequests);
285+
} else {
286+
this.executePendingRequests(recursionDepth + 1);
284287
}
285288
}
289+
} catch (Exception e) {
290+
if (responseSent.compareAndSet(false, true)) {
291+
// Clean up tracking set before notifying failure
292+
trackedRequests.remove(BulkInferenceRequest.this);
293+
completionListener.onFailure(e);
294+
}
286295
}
287-
)
296+
})
288297
);
289298

290299
// Handle null requests (edge case in some iterators)
@@ -305,6 +314,8 @@ private void executePendingRequests(int recursionDepth) {
305314
}
306315
} catch (Exception e) {
307316
executionState.addFailure(e);
317+
// Ensure cleanup on exception - remove from tracking set to prevent memory leak
318+
trackedRequests.remove(this);
308319
}
309320
}
310321

@@ -324,7 +335,9 @@ private void persistPendingResponses() {
324335
if (executionState.hasFailure() == false) {
325336
try {
326337
BulkInferenceResponse response = executionState.fetchBufferedResponse(persistedSeqNo);
327-
responseConsumer.accept(response);
338+
if (response != null) {
339+
responseConsumer.accept(response);
340+
}
328341
} catch (Exception e) {
329342
executionState.addFailure(e);
330343
}
@@ -335,18 +348,28 @@ private void persistPendingResponses() {
335348

336349
/**
337350
* Call the completion listener when all requests have completed.
351+
* Also ensures cleanup of this request from tracking structures to prevent memory leaks.
338352
*/
339353
private void onBulkCompletion() {
340-
if (executionState.hasFailure() == false) {
341-
try {
342-
completionListener.onResponse(null);
343-
return;
344-
} catch (Exception e) {
345-
executionState.addFailure(e);
354+
try {
355+
// Clean up tracking - remove this request from the tracking set
356+
// in case it was queued but never processed
357+
trackedRequests.remove(this);
358+
359+
if (executionState.hasFailure() == false) {
360+
try {
361+
completionListener.onResponse(null);
362+
return;
363+
} catch (Exception e) {
364+
executionState.addFailure(e);
365+
}
346366
}
347-
}
348367

349-
completionListener.onFailure(executionState.getFailure());
368+
completionListener.onFailure(executionState.getFailure());
369+
} finally {
370+
// Ensure we're removed even if completion listener throws
371+
trackedRequests.remove(this);
372+
}
350373
}
351374
}
352375

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionInferenceRequestIterator.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class CompletionInferenceRequestIterator implements BulkInferenceRequestItemIter
3535

3636
private int currentPos = 0;
3737

38+
private static final int[] SHAPE_SINGLE_ONE = new int[] { 1 };
39+
private static final int[] SHAPE_SINGLE_ZERO = new int[] { 0 };
40+
3841
/**
3942
* Constructs a new iterator from the given block of prompts.
4043
*
@@ -77,8 +80,7 @@ public BulkInferenceRequestItem next() {
7780
}
7881

7982
// Create shape array of exact size
80-
int[] shape = Arrays.copyOf(shapeBuffer, shapeSize);
81-
return new BulkInferenceRequestItem(inferenceRequest(nextPrompt), shape);
83+
return new BulkInferenceRequestItem(inferenceRequest(nextPrompt), createShape());
8284
}
8385

8486
private void addToShape(int value) {
@@ -89,6 +91,14 @@ private void addToShape(int value) {
8991
shapeBuffer[shapeSize++] = value;
9092
}
9193

94+
private int[] createShape() {
95+
if (shapeSize == 1) {
96+
return shapeBuffer[0] == 1 ? SHAPE_SINGLE_ONE : SHAPE_SINGLE_ZERO;
97+
}
98+
99+
return Arrays.copyOf(shapeBuffer, shapeSize);
100+
}
101+
92102
/**
93103
* Wraps a single prompt string into an {@link InferenceAction.Request}.
94104
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ private static float[] getEmbeddingAsFloatArray(DenseEmbeddingResults.Embedding<
105105
private static float[] toFloatArray(byte[] values) {
106106
float[] floatArray = new float[values.length];
107107
for (int i = 0; i < values.length; i++) {
108-
floatArray[i] = ((Byte) values[i]).floatValue();
108+
floatArray[i] = (float) values[i];
109109
}
110110
return floatArray;
111111
}

0 commit comments

Comments
 (0)