Skip to content

Commit 875f543

Browse files
committed
[ML] Append all data to Chat Completion buffer
Moved the Chat Completion buffer into the StreamingUnifiedChatCompletionResults so that all Chat Completion responses can benefit from it. Chat Completions is meant to adhere to OpenAI as much as possible, and OpenAI only sends one response chunk at a time. All implementations of Chat Completions will now buffer. This fixes a bug where more than two chunks in a single item would be dropped, instead they are all added to the buffer. This fixes a bug where onComplete would omit trailing items in the buffer.
1 parent 5f256cc commit 875f543

File tree

4 files changed

+146
-59
lines changed

4 files changed

+146
-59
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
1717
import org.elasticsearch.inference.InferenceServiceResults;
1818
import org.elasticsearch.xcontent.ToXContent;
19+
import org.elasticsearch.xpack.core.inference.DequeUtils;
1920

2021
import java.io.IOException;
2122
import java.util.Collections;
2223
import java.util.Deque;
2324
import java.util.Iterator;
2425
import java.util.List;
2526
import java.util.concurrent.Flow;
27+
import java.util.concurrent.LinkedBlockingDeque;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2629

2730
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk;
2831
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
@@ -32,7 +35,7 @@
3235
/**
3336
* Chat Completion results that only contain a Flow.Publisher.
3437
*/
35-
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
38+
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher)
3639
implements
3740
InferenceServiceResults {
3841

@@ -57,6 +60,64 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Inf
5760
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
5861
public static final String TYPE_FIELD = "type";
5962

63+
/**
64+
* OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
65+
* So we will insert a buffer in between the upstream data and the downstream client so that we only send one request at a time.
66+
*/
67+
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
68+
Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
69+
AtomicBoolean onComplete = new AtomicBoolean();
70+
this.publisher = downstream -> {
71+
publisher.subscribe(new Flow.Subscriber<>() {
72+
@Override
73+
public void onSubscribe(Flow.Subscription subscription) {
74+
downstream.onSubscribe(new Flow.Subscription() {
75+
@Override
76+
public void request(long n) {
77+
if (buffer.isEmpty()) {
78+
if (onComplete.get()) {
79+
downstream.onComplete();
80+
} else {
81+
subscription.request(n);
82+
}
83+
} else {
84+
downstream.onNext(new Results(DequeUtils.of(buffer.poll())));
85+
}
86+
}
87+
88+
@Override
89+
public void cancel() {
90+
subscription.cancel();
91+
}
92+
});
93+
}
94+
95+
@Override
96+
public void onNext(Results item) {
97+
var chunks = item.chunks();
98+
var firstItem = chunks.poll();
99+
chunks.forEach(buffer::offer);
100+
downstream.onNext(new Results(DequeUtils.of(firstItem)));
101+
}
102+
103+
@Override
104+
public void onError(Throwable throwable) {
105+
downstream.onError(throwable);
106+
}
107+
108+
@Override
109+
public void onComplete() {
110+
// only complete if the buffer is empty, so that the client has a chance to drain the buffer
111+
if (onComplete.compareAndSet(false, true)) {
112+
if (buffer.isEmpty()) {
113+
downstream.onComplete();
114+
}
115+
}
116+
}
117+
});
118+
};
119+
}
120+
60121
@Override
61122
public boolean isStreaming() {
62123
return true;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,14 @@
1919
import java.util.ArrayDeque;
2020
import java.util.Deque;
2121
import java.util.List;
22+
import java.util.concurrent.Flow;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicInteger;
25+
import java.util.concurrent.atomic.AtomicReference;
2226
import java.util.function.Supplier;
2327

28+
import static org.hamcrest.Matchers.equalTo;
29+
2430
public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase<
2531
StreamingUnifiedChatCompletionResults.Results> {
2632

@@ -198,6 +204,64 @@ public void testToolCallToXContentChunked() throws IOException {
198204
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
199205
}
200206

207+
public void testBufferedPublishing() {
208+
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
209+
results.offer(randomChatCompletionChunk());
210+
results.offer(randomChatCompletionChunk());
211+
var completed = new AtomicBoolean();
212+
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
213+
downstream.onSubscribe(new Flow.Subscription() {
214+
@Override
215+
public void request(long n) {
216+
if(completed.compareAndSet(false, true)) {
217+
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
218+
} else {
219+
downstream.onComplete();
220+
}
221+
}
222+
223+
@Override
224+
public void cancel() {
225+
fail("Cancel should never be called.");
226+
}
227+
});
228+
});
229+
230+
AtomicInteger counter = new AtomicInteger(0);
231+
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
232+
streamingResults.publisher().subscribe(new Flow.Subscriber<>() {
233+
@Override
234+
public void onSubscribe(Flow.Subscription subscription) {
235+
if(upstream.compareAndSet(null, subscription) == false) {
236+
fail("Upstream already set?!");
237+
}
238+
subscription.request(1);
239+
}
240+
241+
@Override
242+
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
243+
assertNotNull(item);
244+
counter.incrementAndGet();
245+
var sub = upstream.get();
246+
if(sub != null) {
247+
sub.request(1);
248+
} else {
249+
fail("Upstream not yet set?!");
250+
}
251+
}
252+
253+
@Override
254+
public void onError(Throwable throwable) {
255+
fail(throwable);
256+
}
257+
258+
@Override
259+
public void onComplete() {
260+
}
261+
});
262+
assertThat(counter.get(), equalTo(2));
263+
}
264+
201265
@Override
202266
protected Writeable.Reader<StreamingUnifiedChatCompletionResults.Results> instanceReader() {
203267
return StreamingUnifiedChatCompletionResults.Results::new;

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xcontent.ToXContent;
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xpack.core.inference.DequeUtils;
3738
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3839
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
3940
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
@@ -256,37 +257,24 @@ public void cancel() {}
256257
"object": "chat.completion.chunk"
257258
}
258259
*/
259-
private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
260-
return new InferenceServiceResults.Result() {
261-
@Override
262-
public String getWriteableName() {
263-
return "test_unifiedCompletionChunk";
264-
}
265-
266-
@Override
267-
public void writeTo(StreamOutput out) throws IOException {
268-
out.writeString(delta);
269-
}
270-
271-
@Override
272-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
273-
return ChunkedToXContentHelper.chunk(
274-
(b, p) -> b.startObject()
275-
.field("id", "id")
276-
.startArray("choices")
277-
.startObject()
278-
.startObject("delta")
279-
.field("content", delta)
280-
.endObject()
281-
.field("index", 0)
282-
.endObject()
283-
.endArray()
284-
.field("model", "gpt-4o-2024-08-06")
285-
.field("object", "chat.completion.chunk")
286-
.endObject()
287-
);
288-
}
289-
};
260+
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
261+
return new StreamingUnifiedChatCompletionResults.Results(
262+
DequeUtils.of(
263+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
264+
"chatcmpl-AarrzyuRflye7yzDF4lmVnenGmQCF",
265+
List.of(
266+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
267+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
268+
null,
269+
0
270+
)
271+
),
272+
"gpt-4o-2024-08-06",
273+
"chat.completion.chunk",
274+
null
275+
)
276+
)
277+
);
290278
}
291279

292280
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.Deque;
2727
import java.util.Iterator;
2828
import java.util.List;
29-
import java.util.concurrent.LinkedBlockingDeque;
3029
import java.util.function.BiFunction;
3130

3231
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -60,21 +59,11 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
6059
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
6160

6261
private final BiFunction<String, Exception, Exception> errorParser;
63-
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
6462

6563
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
6664
this.errorParser = errorParser;
6765
}
6866

69-
@Override
70-
protected void upstreamRequest(long n) {
71-
if (buffer.isEmpty()) {
72-
super.upstreamRequest(n);
73-
} else {
74-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
75-
}
76-
}
77-
7867
@Override
7968
protected void next(Deque<ServerSentEvent> item) throws Exception {
8069
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -96,15 +85,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
9685

9786
if (results.isEmpty()) {
9887
upstream().request(1);
99-
} else if (results.size() == 1) {
100-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10188
} else {
102-
// results > 1, but openai spec only wants 1 chunk per SSE event
103-
var firstItem = singleItem(results.poll());
104-
while (results.isEmpty() == false) {
105-
buffer.offer(results.poll());
106-
}
107-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
89+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10890
}
10991
}
11092

@@ -297,12 +279,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
297279
}
298280
}
299281
}
300-
301-
private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
302-
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
303-
) {
304-
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
305-
deque.offer(result);
306-
return deque;
307-
}
308282
}

0 commit comments

Comments
 (0)