 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.internal.Client;
-import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 
 import java.util.Queue;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Semaphore;
@@ -42,40 +42,27 @@ public class BulkInferenceRunner {
 
     private final Client client;
     private final Semaphore permits;
+    private final int maxRunningTasks;
     private final ExecutorService executor;
 
     /**
-     * Custom concurrent queue that prevents duplicate bulk requests from being queued.
+     * Tracks bulk requests that are currently queued to prevent duplicates.
      * <p>
-     * This queue implementation ensures fairness among multiple concurrent bulk operations
-     * by preventing the same bulk request from being queued multiple times. It uses a
-     * backing concurrent set to track which requests are already queued.
+     * This set ensures fairness among multiple concurrent bulk operations by preventing
+     * the same bulk request from being queued multiple times. Uses ConcurrentHashMap.newKeySet()
+     * for lock-free, thread-safe operations.
      * </p>
      */
-    private final Queue<BulkInferenceRequest> pendingBulkRequests = new ConcurrentLinkedQueue<>() {
-        private final Set<BulkInferenceRequest> requests = ConcurrentCollections.newConcurrentSet();
+    private final Set<BulkInferenceRequest> trackedRequests = ConcurrentHashMap.newKeySet();
 
-        @Override
-        public boolean offer(BulkInferenceRequest bulkInferenceRequest) {
-            synchronized (requests) {
-                if (requests.add(bulkInferenceRequest)) {
-                    return super.offer(bulkInferenceRequest);
-                }
-                return false; // Already exists, don't add duplicate
-            }
-        }
-
-        @Override
-        public BulkInferenceRequest poll() {
-            synchronized (requests) {
-                BulkInferenceRequest request = super.poll();
-                if (request != null) {
-                    requests.remove(request);
-                }
-                return request;
-            }
-        }
-    };
+    /**
+     * Queue of pending bulk requests waiting for permit availability.
+     * <p>
+     * Works in conjunction with {@link #trackedRequests} to ensure no duplicate requests
+     * are queued while maintaining lock-free concurrent access.
+     * </p>
+     */
+    private final Queue<BulkInferenceRequest> pendingBulkRequests = new ConcurrentLinkedQueue<>();
 
     /**
      * Constructs a new throttled inference runner with the specified configuration.
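
The change above replaces the synchronized anonymous `ConcurrentLinkedQueue` subclass with two lock-free pieces: a `ConcurrentHashMap`-backed key set that records which bulk requests are currently queued, and a plain `ConcurrentLinkedQueue` that holds them. A minimal standalone sketch of that idiom, with illustrative class and method names that are not part of this change:

```java
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

// Sketch only: the "tracking set + queue" deduplication used by the change above.
final class DedupTaskQueue<T> {
    private final Set<T> tracked = ConcurrentHashMap.newKeySet();
    private final Queue<T> pending = new ConcurrentLinkedQueue<>();

    /** Enqueue the task unless an identical task is already waiting. */
    boolean offerIfAbsent(T task) {
        if (tracked.add(task)) {   // winning the add() is the only way into the queue
            pending.offer(task);
            return true;
        }
        return false;              // duplicate: already queued, skip
    }

    /** Dequeue the next task and stop tracking it so it may be queued again later. */
    T poll() {
        T task = pending.poll();
        if (task != null) {
            tracked.remove(task);
        }
        return task;
    }
}
```

Because `Set.add()` returns false for a duplicate, a request can be waiting at most once and no synchronized block is needed; in the actual change the `remove` happens at the poll call sites rather than inside a wrapper class.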
@@ -85,6 +72,7 @@ public BulkInferenceRequest poll() {
      */
     public BulkInferenceRunner(Client client, int maxRunningTasks) {
         this.permits = new Semaphore(maxRunningTasks);
+        this.maxRunningTasks = maxRunningTasks;
         this.client = client;
         this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
     }
@@ -142,7 +130,7 @@ private class BulkInferenceRequest {
         private final Consumer<BulkInferenceResponse> responseConsumer;
         private final ActionListener<Void> completionListener;
 
-        private final BulkInferenceExecutionState executionState = new BulkInferenceExecutionState();
+        private final BulkInferenceExecutionState executionState;
         private final AtomicBoolean responseSent = new AtomicBoolean(false);
 
         BulkInferenceRequest(
@@ -153,6 +141,15 @@ private class BulkInferenceRequest {
             this.requests = requests;
             this.responseConsumer = responseConsumer;
             this.completionListener = completionListener;
+
+            // Initialize the buffer capacity for out-of-order responses.
+            // Use half of the smaller of:
+            //   1. maxRunningTasks (the typical out-of-order window under good network conditions)
+            //   2. the estimated request count (so small bulks don't over-allocate)
+            // with a floor of 1. This balances memory efficiency with avoiding rehashing for typical workloads.
+            int estimatedSize = requests.estimatedSize();
+            int bufferCapacity = Math.max(1, Math.min(estimatedSize, maxRunningTasks) / 2);
+            this.executionState = new BulkInferenceExecutionState(bufferCapacity);
         }
 
         /**
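
The new constructor derives the out-of-order response buffer capacity from two inputs: the iterator's size estimate and the concurrency limit. The rule is half of the smaller of the two, floored at 1. A tiny sketch of that arithmetic (helper and class names are hypothetical):

```java
// Sketch of the sizing rule introduced above; names are illustrative.
final class BufferSizing {
    static int bufferCapacity(int estimatedSize, int maxRunningTasks) {
        // Half of the smaller of the two inputs, but never less than 1.
        return Math.max(1, Math.min(estimatedSize, maxRunningTasks) / 2);
    }

    public static void main(String[] args) {
        System.out.println(bufferCapacity(1_000, 10)); // 5 -> large bulk, bounded by the concurrency limit
        System.out.println(bufferCapacity(3, 10));     // 1 -> tiny bulk, bounded by the size estimate
        System.out.println(bufferCapacity(0, 10));     // 1 -> the floor avoids a zero-capacity buffer
    }
}
```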
@@ -180,7 +177,7 @@ private BulkInferenceRequestItem pollPendingRequest() {
          * This method implements a continuation-based asynchronous pattern with the following features:
          * - Queue-based fairness: Multiple bulk requests can be queued and processed fairly
          * - Permit-based concurrency control: Limits concurrent inference requests using semaphores
-         * - Hybrid recursion strategy: Uses direct recursion for performance up to 100 levels,
+         * - Hybrid recursion strategy: Uses direct recursion for performance up to 500 levels,
          * then switches to executor-based continuation to prevent stack overflow
          * - Duplicate prevention: Custom queue prevents the same bulk request from being queued multiple times
          * </p>
@@ -191,7 +188,7 @@ private BulkInferenceRequestItem pollPendingRequest() {
          * 3. Polls for the next available request from the iterator
          * 4. If no requests available, schedules the next queued bulk request
          * 5. Executes the request asynchronously with proper continuation handling
-         * 6. Uses hybrid recursion: direct calls up to 100 levels, executor-based beyond that
+         * 6. Uses hybrid recursion: direct calls up to 500 levels, executor-based beyond that
          * </p>
          * <p>
          * The loop terminates when:
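
The javadoc above describes the hybrid recursion strategy: recurse directly while it is cheap, then hand the continuation to the executor so the stack never grows past the threshold. A self-contained sketch of that depth-bounded continuation pattern, assuming a threshold of 500 to mirror the updated documentation (all names here are illustrative):

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch only: depth-bounded recursion with an executor-based trampoline.
final class HybridRecursionSketch {
    private static final int MAX_DIRECT_DEPTH = 500; // assumed threshold, mirroring the docs above

    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final AtomicInteger remainingWork = new AtomicInteger(10_000);

    void run() {
        run(0); // the no-arg entry point restarts the depth counter
    }

    private void run(int depth) {
        if (remainingWork.decrementAndGet() < 0) {
            executor.shutdown(); // all work done
            return;
        }
        if (depth > MAX_DIRECT_DEPTH) {
            executor.execute(this::run); // hop to the executor: the stack unwinds and depth resets to 0
        } else {
            run(depth + 1); // cheap direct recursion while below the threshold
        }
    }

    public static void main(String[] args) {
        new HybridRecursionSketch().run();
    }
}
```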
@@ -209,7 +206,10 @@ private void executePendingRequests(int recursionDepth) {
             while (executionState.finished() == false) {
                 if (permits.tryAcquire() == false) {
                     if (requests.hasNext()) {
-                        pendingBulkRequests.add(this);
+                        // Add to the tracking set first to prevent duplicates
+                        if (trackedRequests.add(this)) {
+                            pendingBulkRequests.offer(this);
+                        }
                     }
                     return;
                 } else {
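
This hunk is the non-blocking admission path: `Semaphore.tryAcquire()` rather than a blocking acquire, with a deferred self-enqueue when no permit is free. A rough standalone sketch of the same gate with illustrative names and an arbitrary permit count; the real code additionally deduplicates what it parks (see the tracking-set sketch above) and re-polls the pending queue after releasing a permit:

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;

// Sketch only: run the task now if a permit is free, otherwise park it for later.
final class PermitGateSketch {
    private final Semaphore permits = new Semaphore(4); // illustrative limit
    private final Queue<Runnable> parked = new ConcurrentLinkedQueue<>();
    private final ExecutorService executor = Executors.newFixedThreadPool(4);

    void submit(Runnable task) {
        if (permits.tryAcquire() == false) {
            parked.offer(task); // no permit available: defer instead of blocking the caller
            return;
        }
        executor.execute(() -> {
            try {
                task.run();
            } finally {
                permits.release();             // free the permit first...
                Runnable next = parked.poll(); // ...then give a parked task its turn
                if (next != null) {
                    submit(next);
                }
            }
        });
    }
}
```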
@@ -228,6 +228,10 @@ private void executePendingRequests(int recursionDepth) {
                 }
 
                 if (nexBulkRequest != null) {
+                    // Remove from the tracking set since we're about to process it
+                    trackedRequests.remove(nexBulkRequest);
+                    // Execute the next bulk request with the recursion depth reset
+                    // (the no-arg method reference starts again at depth 0)
                     executor.execute(nexBulkRequest::executePendingRequests);
                 }
 
@@ -242,49 +246,54 @@ private void executePendingRequests(int recursionDepth) {
 
             final ActionListener<InferenceAction.Response> inferenceResponseListener = new ThreadedActionListener<>(
                 executor,
-                ActionListener.runAfter(
-                    ActionListener.wrap(
-                        r -> executionState.onInferenceResponse(new BulkInferenceResponse(bulkInferenceRequestItem, r)),
-                        e -> executionState.onInferenceException(bulkInferenceRequestItem.seqNo(), e)
-                    ),
-                    () -> {
-                        // Release the permit we used
-                        permits.release();
-
-                        try {
-                            synchronized (executionState) {
-                                persistPendingResponses();
-                            }
+                ActionListener.runAfter(ActionListener.wrap(r -> {
+                    BulkInferenceResponse bulkResponse = new BulkInferenceResponse(bulkInferenceRequestItem, r);
+                    executionState.onInferenceResponse(bulkResponse);
+                }, e -> executionState.onInferenceException(bulkInferenceRequestItem.seqNo(), e)), () -> {
+                    // Release the permit we used
+                    permits.release();
+
+                    try {
+                        synchronized (executionState) {
+                            persistPendingResponses();
+                        }
 
-                            if (executionState.finished() && responseSent.compareAndSet(false, true)) {
-                                onBulkCompletion();
-                            }
+                        if (executionState.finished() && responseSent.compareAndSet(false, true)) {
+                            onBulkCompletion();
+                        }
 
-                            if (responseSent.get()) {
-                                // Response has already been sent
-                                // No need to continue processing this bulk.
-                                // Check if another bulk request is pending for execution.
-                                BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
-                                if (nexBulkRequest != null) {
-                                    executor.execute(nexBulkRequest::executePendingRequests);
-                                }
-                                return;
+                        if (responseSent.get()) {
+                            // Response has already been sent
+                            // No need to continue processing this bulk.
+                            // Check if another bulk request is pending for execution.
+                            BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
+                            if (nexBulkRequest != null) {
+                                // Remove from the tracking set since we're about to process it
+                                trackedRequests.remove(nexBulkRequest);
+                                // Execute the next bulk request with the recursion depth reset
+                                // (the no-arg method reference starts again at depth 0)
+                                executor.execute(nexBulkRequest::executePendingRequests);
                             }
-                            if (executionState.finished() == false) {
-                                // Execute any pending requests if any
-                                if (recursionDepth > 100) {
-                                    executor.execute(this::executePendingRequests);
-                                } else {
-                                    this.executePendingRequests(recursionDepth + 1);
-                                }
-                            }
-                        } catch (Exception e) {
-                            if (responseSent.compareAndSet(false, true)) {
-                                completionListener.onFailure(e);
+                            return;
+                        }
+                        if (executionState.finished() == false) {
+                            // Execute any pending requests if any
+                            if (recursionDepth > 500) {
+                                // Reset the recursion depth by submitting to the executor;
+                                // this prevents unbounded stack growth while maintaining performance
+                                executor.execute(this::executePendingRequests);
+                            } else {
+                                this.executePendingRequests(recursionDepth + 1);
                             }
                         }
+                    } catch (Exception e) {
+                        if (responseSent.compareAndSet(false, true)) {
+                            // Clean up the tracking set before notifying failure
+                            trackedRequests.remove(BulkInferenceRequest.this);
+                            completionListener.onFailure(e);
+                        }
+                    }
-                    }
-                )
+                })
             );
 
             // Handle null requests (edge case in some iterators)
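
The reformatted listener keeps its two-layer shape: `ActionListener.wrap` separates the success and failure paths, and `ActionListener.runAfter` guarantees that the permit release and the follow-up scheduling run whichever path was taken. A JDK-only sketch of the equivalent flow using CompletableFuture, with hypothetical helper names standing in for the execution-state and queue plumbing:

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;

// Sketch only: "wrap the two outcomes, then always run after", expressed with CompletableFuture.
final class CompletionHookSketch {
    private final Semaphore permits = new Semaphore(4);
    private final ExecutorService executor = Executors.newFixedThreadPool(4);

    void execute(CompletableFuture<String> inferenceCall) {
        // A permit is assumed to have been acquired before the call was dispatched (not shown).
        inferenceCall
            // Equivalent of ActionListener.wrap(onResponse, onFailure): record one outcome or the other.
            .whenCompleteAsync((response, error) -> {
                if (error != null) {
                    recordFailure(error);
                } else {
                    recordResponse(response);
                }
            }, executor)
            // Equivalent of ActionListener.runAfter(listener, () -> ...): always runs afterwards.
            .whenCompleteAsync((ignored, ignoredError) -> {
                permits.release();
                scheduleNextPendingWork();
            }, executor);
    }

    private void recordResponse(String response) { /* buffer the response for in-order persistence */ }
    private void recordFailure(Throwable error) { /* mark the bulk execution as failed */ }
    private void scheduleNextPendingWork() { /* poll the pending queue and continue processing */ }
}
```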
@@ -305,6 +314,8 @@ private void executePendingRequests(int recursionDepth) {
                 }
             } catch (Exception e) {
                 executionState.addFailure(e);
+                // Ensure cleanup on exception - remove from the tracking set to prevent a memory leak
+                trackedRequests.remove(this);
             }
         }
 
@@ -324,7 +335,9 @@ private void persistPendingResponses() {
                 if (executionState.hasFailure() == false) {
                     try {
                         BulkInferenceResponse response = executionState.fetchBufferedResponse(persistedSeqNo);
-                        responseConsumer.accept(response);
+                        if (response != null) {
+                            responseConsumer.accept(response);
+                        }
                     } catch (Exception e) {
                         executionState.addFailure(e);
                     }
@@ -335,18 +348,28 @@ private void persistPendingResponses() {
 
         /**
          * Call the completion listener when all requests have completed.
+         * Also ensures cleanup of this request from tracking structures to prevent memory leaks.
          */
         private void onBulkCompletion() {
-            if (executionState.hasFailure() == false) {
-                try {
-                    completionListener.onResponse(null);
-                    return;
-                } catch (Exception e) {
-                    executionState.addFailure(e);
+            try {
+                // Clean up tracking: remove this request from the tracking set
+                // in case it was queued but never processed
+                trackedRequests.remove(this);
+
+                if (executionState.hasFailure() == false) {
+                    try {
+                        completionListener.onResponse(null);
+                        return;
+                    } catch (Exception e) {
+                        executionState.addFailure(e);
+                    }
                 }
-            }
 
-            completionListener.onFailure(executionState.getFailure());
+                completionListener.onFailure(executionState.getFailure());
+            } finally {
+                // Ensure we're removed even if the completion listener throws
+                trackedRequests.remove(this);
+            }
         }
     }
 
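
The reworked onBulkCompletion pairs the caller's compareAndSet guard with try/finally cleanup, so the request is detached from the tracking set even when the completion listener itself throws. A minimal sketch of that notify-once, always-clean-up shape (names are illustrative):

```java
import java.util.concurrent.atomic.AtomicBoolean;

// Sketch only: complete at most once, and always run the cleanup.
final class CompleteOnceSketch {
    private final AtomicBoolean responseSent = new AtomicBoolean(false);

    void complete(Runnable notifyListener, Runnable removeFromTracking) {
        if (responseSent.compareAndSet(false, true) == false) {
            return; // another path already completed this request
        }
        try {
            notifyListener.run();       // may throw
        } finally {
            removeFromTracking.run();   // cleanup happens regardless of the outcome
        }
    }
}
```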