1010import org .elasticsearch .action .ActionListener ;
1111import org .elasticsearch .client .internal .Client ;
1212import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
13+ import org .elasticsearch .inference .TaskType ;
1314import org .elasticsearch .threadpool .ThreadPool ;
1415import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
16+ import org .elasticsearch .xpack .core .inference .action .UnifiedCompletionAction ;
1517
1618import java .util .ArrayList ;
1719import java .util .List ;
@@ -175,12 +177,12 @@ private class BulkInferenceRequest {
175177 * to the request iterator.
176178 * </p>
177179 *
178- * @return A BulkRequestItem if a request and permit are available, null otherwise
180+ * @return A BulkInferenceRequestItem if a request and permit are available, null otherwise
179181 */
180- private BulkRequestItem pollPendingRequest () {
182+ private BulkInferenceRequestItem <?> pollPendingRequest () {
181183 synchronized (requests ) {
182184 if (requests .hasNext ()) {
183- return new BulkRequestItem ( executionState . generateSeqNo (), requests . next ());
185+ return requests . next (). withSeqNo ( executionState . generateSeqNo ());
184186 }
185187 }
186188
@@ -226,22 +228,22 @@ private void executePendingRequests(int recursionDepth) {
226228 }
227229 return ;
228230 } else {
229- BulkRequestItem bulkRequestItem = pollPendingRequest ();
231+ BulkInferenceRequestItem <?> bulkRequestItem = pollPendingRequest ();
230232
231233 if (bulkRequestItem == null ) {
232234 // No more requests available
233235 // Release the permit we didn't use and stop processing
234236 permits .release ();
235237
236238 // Check if another bulk request is pending for execution.
237- BulkInferenceRequest nexBulkRequest = pendingBulkRequests .poll ();
239+ BulkInferenceRequest nextBulkRequest = pendingBulkRequests .poll ();
238240
239- while (nexBulkRequest == this ) {
240- nexBulkRequest = pendingBulkRequests .poll ();
241+ while (nextBulkRequest == this ) {
242+ nextBulkRequest = pendingBulkRequests .poll ();
241243 }
242244
243- if (nexBulkRequest != null ) {
244- executor .execute (nexBulkRequest ::executePendingRequests );
245+ if (nextBulkRequest != null ) {
246+ executor .execute (nextBulkRequest ::executePendingRequests );
245247 }
246248
247249 return ;
@@ -275,9 +277,9 @@ private void executePendingRequests(int recursionDepth) {
275277 // Response has already been sent
276278 // No need to continue processing this bulk.
277279 // Check if another bulk request is pending for execution.
278- BulkInferenceRequest nexBulkRequest = pendingBulkRequests .poll ();
279- if (nexBulkRequest != null ) {
280- executor .execute (nexBulkRequest ::executePendingRequests );
280+ BulkInferenceRequest nextBulkRequest = pendingBulkRequests .poll ();
281+ if (nextBulkRequest != null ) {
282+ executor .execute (nextBulkRequest ::executePendingRequests );
281283 }
282284 return ;
283285 }
@@ -298,26 +300,57 @@ private void executePendingRequests(int recursionDepth) {
298300 );
299301
300302 // Handle null requests (edge case in some iterators)
301- if (bulkRequestItem .request () == null ) {
303+ if (bulkRequestItem .inferenceRequest () == null ) {
302304 inferenceResponseListener .onResponse (null );
303305 return ;
304306 }
305307
306308 // Execute the inference request with proper origin context
307- executeAsyncWithOrigin (
308- client ,
309- INFERENCE_ORIGIN ,
310- InferenceAction .INSTANCE ,
311- bulkRequestItem .request (),
312- inferenceResponseListener
313- );
309+ if (bulkRequestItem .taskType () == TaskType .CHAT_COMPLETION ) {
310+ handleStreamingRequest (
311+ (UnifiedCompletionAction .Request ) bulkRequestItem .inferenceRequest (),
312+ inferenceResponseListener
313+ );
314+ } else {
315+ executeAsyncWithOrigin (
316+ client ,
317+ INFERENCE_ORIGIN ,
318+ InferenceAction .INSTANCE ,
319+ bulkRequestItem .inferenceRequest (),
320+ inferenceResponseListener
321+ );
322+ }
314323 }
315324 }
316325 } catch (Exception e ) {
317326 executionState .addFailure (e );
318327 }
319328 }
320329
330+ /**
331+ * Handles streaming inference requests for chat completion tasks.
332+ * <p>
333+ * Executes the request through {@code UnifiedCompletionAction.INSTANCE} under {@code INFERENCE_ORIGIN}.
334+ * On success, a new BulkInferenceStreamingHandler wrapping the listener is subscribed to the
335+ * response publisher; the handler is expected to consume the stream and complete the listener.
336+ * A failure of the action itself is passed to the listener via {@code delegateFailureAndWrap}.
337+ * </p>
338+ *
339+ * @param request The UnifiedCompletionAction request to execute
340+ * @param listener The listener handed to the streaming handler; presumably receives the final aggregated response (handler not shown here)
341+ */
342+ private void handleStreamingRequest (UnifiedCompletionAction .Request request , ActionListener <InferenceAction .Response > listener ) {
343+ executeAsyncWithOrigin (
344+ client ,
345+ INFERENCE_ORIGIN ,
346+ UnifiedCompletionAction .INSTANCE ,
347+ request ,
348+ listener .delegateFailureAndWrap ((l , inferenceResponse ) -> {
349+ inferenceResponse .publisher ().subscribe (new BulkInferenceStreamingHandler (l ));
350+ })
351+ );
352+ }
353+
321354 /**
322355 * Processes and delivers buffered responses in order, ensuring proper sequencing.
323356 * <p>
@@ -360,20 +393,6 @@ private void onBulkCompletion() {
360393 }
361394 }
362395
363- /**
364- * Encapsulates an inference request with its associated sequence number.
365- * <p>
366- * The sequence number is used for ordering responses and tracking completion
367- * in the bulk execution state.
368- * </p>
369- *
370- * @param seqNo Unique sequence number for this request in the bulk operation
371- * @param request The actual inference request to execute
372- */
373- private record BulkRequestItem (long seqNo , InferenceAction .Request request ) {
374-
375- }
376-
377396 public static Factory factory (Client client ) { // Returns a Factory that builds a BulkInferenceRunner for this client, sized by the config's maxOutstandingBulkRequests()
378397 return inferenceRunnerConfig -> new BulkInferenceRunner (client , inferenceRunnerConfig .maxOutstandingBulkRequests ());
379398 }
0 commit comments