Skip to content

Commit bc6eaaa

Browse files
committed
Implements streaming support in the completion operator.
1 parent 7d986e6 commit bc6eaaa

File tree

16 files changed

+599
-102
lines changed

16 files changed

+599
-102
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.bulk;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
12+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
13+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
14+
15+
import java.util.Objects;
16+
17+
sealed public interface BulkInferenceRequestItem<T extends BaseInferenceActionRequest> permits
18+
BulkInferenceRequestItem.AbstractBulkInferenceRequestItem {
19+
20+
TaskType taskType();
21+
22+
T inferenceRequest();
23+
24+
BulkInferenceRequestItem<T> withSeqNo(long seqNo);
25+
26+
Long seqNo();
27+
28+
static InferenceRequestItem from(InferenceAction.Request request) {
29+
return new InferenceRequestItem(request);
30+
}
31+
32+
static ChatCompletionRequestItem from(UnifiedCompletionAction.Request request) {
33+
return new ChatCompletionRequestItem(request);
34+
}
35+
36+
abstract sealed class AbstractBulkInferenceRequestItem<T extends BaseInferenceActionRequest> implements BulkInferenceRequestItem<T>
37+
permits InferenceRequestItem, ChatCompletionRequestItem {
38+
private final T request;
39+
private final Long seqNo;
40+
41+
protected AbstractBulkInferenceRequestItem(T request) {
42+
this(request, null);
43+
}
44+
45+
protected AbstractBulkInferenceRequestItem(T request, Long seqNo) {
46+
this.request = request;
47+
this.seqNo = seqNo;
48+
}
49+
50+
@Override
51+
public T inferenceRequest() {
52+
return request;
53+
}
54+
55+
@Override
56+
public Long seqNo() {
57+
return seqNo;
58+
}
59+
60+
@Override
61+
public boolean equals(Object o) {
62+
if (o == null || getClass() != o.getClass()) return false;
63+
AbstractBulkInferenceRequestItem<?> that = (AbstractBulkInferenceRequestItem<?>) o;
64+
return Objects.equals(request, that.request) && Objects.equals(seqNo, that.seqNo);
65+
}
66+
67+
@Override
68+
public int hashCode() {
69+
return Objects.hash(request, seqNo);
70+
}
71+
72+
@Override
73+
public TaskType taskType() {
74+
return request.getTaskType();
75+
}
76+
}
77+
78+
final class InferenceRequestItem extends AbstractBulkInferenceRequestItem<InferenceAction.Request> {
79+
private InferenceRequestItem(InferenceAction.Request request) {
80+
this(request, null);
81+
}
82+
83+
private InferenceRequestItem(InferenceAction.Request request, Long seqNo) {
84+
super(request, seqNo);
85+
}
86+
87+
@Override
88+
public InferenceRequestItem withSeqNo(long seqNo) {
89+
return new InferenceRequestItem(inferenceRequest(), seqNo);
90+
}
91+
}
92+
93+
final class ChatCompletionRequestItem extends AbstractBulkInferenceRequestItem<UnifiedCompletionAction.Request> {
94+
95+
private ChatCompletionRequestItem(UnifiedCompletionAction.Request request) {
96+
this(request, null);
97+
}
98+
99+
private ChatCompletionRequestItem(UnifiedCompletionAction.Request request, Long seqNo) {
100+
super(request, seqNo);
101+
}
102+
103+
@Override
104+
public TaskType taskType() {
105+
return TaskType.CHAT_COMPLETION;
106+
}
107+
108+
@Override
109+
public ChatCompletionRequestItem withSeqNo(long seqNo) {
110+
return new ChatCompletionRequestItem(inferenceRequest(), seqNo);
111+
}
112+
}
113+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,15 @@
88
package org.elasticsearch.xpack.esql.inference.bulk;
99

1010
import org.elasticsearch.core.Releasable;
11-
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1211

1312
import java.util.Iterator;
1413

15-
public interface BulkInferenceRequestIterator extends Iterator<InferenceAction.Request>, Releasable {
14+
public interface BulkInferenceRequestIterator extends Iterator<BulkInferenceRequestItem<?>>, Releasable {
1615

1716
/**
1817
* Returns an estimate of the number of requests that will be produced.
1918
*
20-
* <p>This is typically used to pre-allocate buffers or output to th appropriate size.</p>
19+
* <p>This is typically used to pre-allocate buffers or output to the appropriate size.</p>
2120
*/
2221
int estimatedSize();
23-
2422
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.client.internal.Client;
1212
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
13+
import org.elasticsearch.inference.TaskType;
1314
import org.elasticsearch.threadpool.ThreadPool;
1415
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
16+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
1517

1618
import java.util.ArrayList;
1719
import java.util.List;
@@ -175,12 +177,12 @@ private class BulkInferenceRequest {
175177
* to the request iterator.
176178
* </p>
177179
*
178-
* @return A BulkRequestItem if a request and permit are available, null otherwise
180+
* @return A BulkInferenceRequestItem if a request and permit are available, null otherwise
179181
*/
180-
private BulkRequestItem pollPendingRequest() {
182+
private BulkInferenceRequestItem<?> pollPendingRequest() {
181183
synchronized (requests) {
182184
if (requests.hasNext()) {
183-
return new BulkRequestItem(executionState.generateSeqNo(), requests.next());
185+
return requests.next().withSeqNo(executionState.generateSeqNo());
184186
}
185187
}
186188

@@ -226,22 +228,22 @@ private void executePendingRequests(int recursionDepth) {
226228
}
227229
return;
228230
} else {
229-
BulkRequestItem bulkRequestItem = pollPendingRequest();
231+
BulkInferenceRequestItem<?> bulkRequestItem = pollPendingRequest();
230232

231233
if (bulkRequestItem == null) {
232234
// No more requests available
233235
// Release the permit we didn't use and stop processing
234236
permits.release();
235237

236238
// Check if another bulk request is pending for execution.
237-
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
239+
BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
238240

239-
while (nexBulkRequest == this) {
240-
nexBulkRequest = pendingBulkRequests.poll();
241+
while (nextBulkRequest == this) {
242+
nextBulkRequest = pendingBulkRequests.poll();
241243
}
242244

243-
if (nexBulkRequest != null) {
244-
executor.execute(nexBulkRequest::executePendingRequests);
245+
if (nextBulkRequest != null) {
246+
executor.execute(nextBulkRequest::executePendingRequests);
245247
}
246248

247249
return;
@@ -275,9 +277,9 @@ private void executePendingRequests(int recursionDepth) {
275277
// Response has already been sent
276278
// No need to continue processing this bulk.
277279
// Check if another bulk request is pending for execution.
278-
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
279-
if (nexBulkRequest != null) {
280-
executor.execute(nexBulkRequest::executePendingRequests);
280+
BulkInferenceRequest nextBulkRequest = pendingBulkRequests.poll();
281+
if (nextBulkRequest != null) {
282+
executor.execute(nextBulkRequest::executePendingRequests);
281283
}
282284
return;
283285
}
@@ -298,26 +300,57 @@ private void executePendingRequests(int recursionDepth) {
298300
);
299301

300302
// Handle null requests (edge case in some iterators)
301-
if (bulkRequestItem.request() == null) {
303+
if (bulkRequestItem.inferenceRequest() == null) {
302304
inferenceResponseListener.onResponse(null);
303305
return;
304306
}
305307

306308
// Execute the inference request with proper origin context
307-
executeAsyncWithOrigin(
308-
client,
309-
INFERENCE_ORIGIN,
310-
InferenceAction.INSTANCE,
311-
bulkRequestItem.request(),
312-
inferenceResponseListener
313-
);
309+
if (bulkRequestItem.taskType() == TaskType.CHAT_COMPLETION) {
310+
handleStreamingRequest(
311+
(UnifiedCompletionAction.Request) bulkRequestItem.inferenceRequest(),
312+
inferenceResponseListener
313+
);
314+
} else {
315+
executeAsyncWithOrigin(
316+
client,
317+
INFERENCE_ORIGIN,
318+
InferenceAction.INSTANCE,
319+
bulkRequestItem.inferenceRequest(),
320+
inferenceResponseListener
321+
);
322+
}
314323
}
315324
}
316325
} catch (Exception e) {
317326
executionState.addFailure(e);
318327
}
319328
}
320329

330+
/**
331+
* Handles streaming inference requests for chat completion tasks.
332+
* <p>
333+
* This method executes UnifiedCompletionAction requests and sets up proper streaming
334+
* response handling through the BulkInferenceStreamingHandler. The streaming handler
335+
* manages the asynchronous stream processing and ensures responses are properly
336+
* delivered to the completion listener.
337+
* </p>
338+
*
339+
* @param request The UnifiedCompletionAction request to execute
340+
* @param listener The listener to receive the final aggregated response
341+
*/
342+
private void handleStreamingRequest(UnifiedCompletionAction.Request request, ActionListener<InferenceAction.Response> listener) {
343+
executeAsyncWithOrigin(
344+
client,
345+
INFERENCE_ORIGIN,
346+
UnifiedCompletionAction.INSTANCE,
347+
request,
348+
listener.delegateFailureAndWrap((l, inferenceResponse) -> {
349+
inferenceResponse.publisher().subscribe(new BulkInferenceStreamingHandler(l));
350+
})
351+
);
352+
}
353+
321354
/**
322355
* Processes and delivers buffered responses in order, ensuring proper sequencing.
323356
* <p>
@@ -360,20 +393,6 @@ private void onBulkCompletion() {
360393
}
361394
}
362395

363-
/**
364-
* Encapsulates an inference request with its associated sequence number.
365-
* <p>
366-
* The sequence number is used for ordering responses and tracking completion
367-
* in the bulk execution state.
368-
* </p>
369-
*
370-
* @param seqNo Unique sequence number for this request in the bulk operation
371-
* @param request The actual inference request to execute
372-
*/
373-
private record BulkRequestItem(long seqNo, InferenceAction.Request request) {
374-
375-
}
376-
377396
public static Factory factory(Client client) {
378397
return inferenceRunnerConfig -> new BulkInferenceRunner(client, inferenceRunnerConfig.maxOutstandingBulkRequests());
379398
}

0 commit comments

Comments
 (0)