Skip to content

Commit 910e158

Browse files
committed
[ML] Append all data to Chat Completion buffer (#127658)
Moved the Chat Completion buffer into StreamingUnifiedChatCompletionResults so that all Chat Completion responses can benefit from it. Chat Completions is meant to adhere to the OpenAI spec as closely as possible, and OpenAI only sends one response chunk at a time, so all implementations of Chat Completions will now buffer. This fixes a bug where chunks would be dropped when a single item contained more than two of them; they are now all added to the buffer. It also fixes a bug where onComplete would omit trailing items still in the buffer.
1 parent 2c6e3a4 commit 910e158

File tree

6 files changed

+246
-47
lines changed

6 files changed

+246
-47
lines changed

docs/changelog/127658.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127658
2+
summary: Append all data to Chat Completion buffer
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.Writeable;
12+
13+
import java.io.IOException;
14+
import java.util.ArrayDeque;
15+
import java.util.Deque;
16+
17+
public final class DequeUtils {
18+
19+
private DequeUtils() {
20+
// util functions only
21+
}
22+
23+
public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
24+
return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in))));
25+
}
26+
27+
public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
28+
if (thisDeque.size() != otherDeque.size()) {
29+
return false;
30+
}
31+
var thisIter = thisDeque.iterator();
32+
var otherIter = otherDeque.iterator();
33+
while (thisIter.hasNext() && otherIter.hasNext()) {
34+
if (thisIter.next().equals(otherIter.next()) == false) {
35+
return false;
36+
}
37+
}
38+
return true;
39+
}
40+
41+
public static int dequeHashCode(Deque<?> deque) {
42+
if (deque == null) {
43+
return 0;
44+
}
45+
return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
46+
}
47+
48+
public static <T> Deque<T> of(T elem) {
49+
var deque = new ArrayDeque<T>(1);
50+
deque.offer(elem);
51+
return deque;
52+
}
53+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.inference.InferenceResults;
1616
import org.elasticsearch.inference.InferenceServiceResults;
1717
import org.elasticsearch.xcontent.ToXContent;
18+
import org.elasticsearch.xpack.core.inference.DequeUtils;
1819

1920
import java.io.IOException;
2021
import java.util.Collections;
@@ -23,15 +24,15 @@
2324
import java.util.List;
2425
import java.util.Map;
2526
import java.util.concurrent.Flow;
27+
import java.util.concurrent.LinkedBlockingDeque;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2629

2730
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk;
2831

2932
/**
3033
* Chat Completion results that only contain a Flow.Publisher.
3134
*/
32-
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher)
33-
implements
34-
InferenceServiceResults {
35+
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {
3536

3637
public static final String NAME = "chat_completion_chunk";
3738
public static final String MODEL_FIELD = "model";
@@ -54,6 +55,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
5455
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
5556
public static final String TYPE_FIELD = "type";
5657

58+
/**
 * OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
 * So we insert a buffer in between the upstream data and the downstream client so that we only send one chunk at a time.
 * <p>
 * When an upstream item carries several chunks, the first one is forwarded immediately and the rest are buffered.
 * Later downstream requests are served from the buffer before anything more is requested from upstream.
 * Completion is only forwarded once the buffer is empty.
 *
 * @param publisher the upstream publisher whose items may each contain multiple chunks
 */
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
    Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
    AtomicBoolean onComplete = new AtomicBoolean();
    this.publisher = downstream -> {
        publisher.subscribe(new Flow.Subscriber<>() {
            // kept so that onNext can ask for more when upstream hands over an item with no chunks
            private volatile Flow.Subscription upstream;

            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                upstream = subscription;
                downstream.onSubscribe(new Flow.Subscription() {
                    @Override
                    public void request(long n) {
                        // serve buffered chunks first, one per request, before pulling from upstream
                        var nextItem = buffer.poll();
                        if (nextItem != null) {
                            downstream.onNext(new Results(DequeUtils.of(nextItem)));
                        } else if (onComplete.get()) {
                            // upstream already finished and the buffer is drained, so finish downstream now
                            downstream.onComplete();
                        } else {
                            subscription.request(n);
                        }
                    }

                    @Override
                    public void cancel() {
                        subscription.cancel();
                    }
                });
            }

            @Override
            public void onNext(Results item) {
                var chunks = item.chunks();
                var firstItem = chunks.poll();
                if (firstItem == null) {
                    // An empty item has nothing to forward, and wrapping null in a deque would throw.
                    // The downstream request that led here is still unanswered, so pull the next item instead.
                    upstream.request(1);
                    return;
                }
                chunks.forEach(buffer::offer);
                downstream.onNext(new Results(DequeUtils.of(firstItem)));
            }

            @Override
            public void onError(Throwable throwable) {
                downstream.onError(throwable);
            }

            @Override
            public void onComplete() {
                // only complete if the buffer is empty, so that the client has a chance to drain the buffer
                if (onComplete.compareAndSet(false, true)) {
                    if (buffer.isEmpty()) {
                        downstream.onComplete();
                    }
                }
            }
        });
    };
}
114+
57115
@Override
58116
public boolean isStreaming() {
59117
return true;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818
import java.util.ArrayDeque;
1919
import java.util.Deque;
2020
import java.util.List;
21+
import java.util.concurrent.Flow;
22+
import java.util.concurrent.atomic.AtomicBoolean;
23+
import java.util.concurrent.atomic.AtomicInteger;
24+
import java.util.concurrent.atomic.AtomicReference;
25+
import java.util.function.Supplier;
26+
27+
import static org.mockito.ArgumentMatchers.any;
28+
import static org.mockito.Mockito.spy;
29+
import static org.mockito.Mockito.times;
30+
import static org.mockito.Mockito.verify;
2131

2232
public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {
2333

@@ -195,4 +205,99 @@ public void testToolCallToXContentChunked() throws IOException {
195205
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
196206
}
197207

208+
public void testBufferedPublishing() {
209+
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
210+
results.offer(randomChatCompletionChunk());
211+
results.offer(randomChatCompletionChunk());
212+
var completed = new AtomicBoolean();
213+
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
214+
downstream.onSubscribe(new Flow.Subscription() {
215+
@Override
216+
public void request(long n) {
217+
if (completed.compareAndSet(false, true)) {
218+
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
219+
} else {
220+
downstream.onComplete();
221+
}
222+
}
223+
224+
@Override
225+
public void cancel() {
226+
fail("Cancel should never be called.");
227+
}
228+
});
229+
});
230+
231+
AtomicInteger counter = new AtomicInteger(0);
232+
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
233+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = spy(
234+
new Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results>() {
235+
@Override
236+
public void onSubscribe(Flow.Subscription subscription) {
237+
if (upstream.compareAndSet(null, subscription) == false) {
238+
fail("Upstream already set?!");
239+
}
240+
subscription.request(1);
241+
}
242+
243+
@Override
244+
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
245+
assertNotNull(item);
246+
counter.incrementAndGet();
247+
var sub = upstream.get();
248+
if (sub != null) {
249+
sub.request(1);
250+
} else {
251+
fail("Upstream not yet set?!");
252+
}
253+
}
254+
255+
@Override
256+
public void onError(Throwable throwable) {
257+
fail(throwable);
258+
}
259+
260+
@Override
261+
public void onComplete() {}
262+
}
263+
);
264+
streamingResults.publisher().subscribe(subscriber);
265+
verify(subscriber, times(2)).onNext(any());
266+
}
267+
268+
private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() {
269+
Supplier<String> randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5);
270+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
271+
randomAlphanumericOfLength(5),
272+
randomBoolean() ? null : randomList(randomInt(5), () -> {
273+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
274+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
275+
randomOptionalString.get(),
276+
randomOptionalString.get(),
277+
randomOptionalString.get(),
278+
randomBoolean() ? null : randomList(randomInt(5), () -> {
279+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
280+
randomInt(5),
281+
randomOptionalString.get(),
282+
randomBoolean()
283+
? null
284+
: new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
285+
randomOptionalString.get(),
286+
randomOptionalString.get()
287+
),
288+
randomOptionalString.get()
289+
);
290+
})
291+
),
292+
randomOptionalString.get(),
293+
randomInt(5)
294+
);
295+
}),
296+
randomAlphanumericOfLength(5),
297+
randomAlphanumericOfLength(5),
298+
randomBoolean()
299+
? null
300+
: new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5))
301+
);
302+
}
198303
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.rest.RestStatus;
3333
import org.elasticsearch.xcontent.ToXContentObject;
3434
import org.elasticsearch.xcontent.XContentBuilder;
35+
import org.elasticsearch.xpack.core.inference.DequeUtils;
3536
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3637
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
3738

@@ -198,21 +199,23 @@ public void cancel() {}
198199
"object": "chat.completion.chunk"
199200
}
200201
*/
201-
private ChunkedToXContent unifiedCompletionChunk(String delta) {
202-
return params -> ChunkedToXContentHelper.chunk(
203-
(b, p) -> b.startObject()
204-
.field("id", "id")
205-
.startArray("choices")
206-
.startObject()
207-
.startObject("delta")
208-
.field("content", delta)
209-
.endObject()
210-
.field("index", 0)
211-
.endObject()
212-
.endArray()
213-
.field("model", "gpt-4o-2024-08-06")
214-
.field("object", "chat.completion.chunk")
215-
.endObject()
202+
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
203+
return new StreamingUnifiedChatCompletionResults.Results(
204+
DequeUtils.of(
205+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
206+
"id",
207+
List.of(
208+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
209+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
210+
null,
211+
0
212+
)
213+
),
214+
"gpt-4o-2024-08-06",
215+
"chat.completion.chunk",
216+
null
217+
)
218+
)
216219
);
217220
}
218221

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1312
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1413
import org.elasticsearch.xcontent.ConstructingObjectParser;
1514
import org.elasticsearch.xcontent.ParseField;
@@ -28,13 +27,14 @@
2827
import java.util.Deque;
2928
import java.util.Iterator;
3029
import java.util.List;
31-
import java.util.concurrent.LinkedBlockingDeque;
3230
import java.util.function.BiFunction;
3331

3432
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3533
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
3634

37-
public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> {
35+
public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
36+
Deque<ServerSentEvent>,
37+
StreamingUnifiedChatCompletionResults.Results> {
3838
public static final String FUNCTION_FIELD = "function";
3939
private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class);
4040

@@ -60,22 +60,12 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
6060
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
6161

6262
private final BiFunction<String, Exception, Exception> errorParser;
63-
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
6463
private volatile boolean previousEventWasError = false;
6564

6665
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
6766
this.errorParser = errorParser;
6867
}
6968

70-
@Override
71-
protected void upstreamRequest(long n) {
72-
if (buffer.isEmpty()) {
73-
super.upstreamRequest(n);
74-
} else {
75-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
76-
}
77-
}
78-
7969
@Override
8070
protected void next(Deque<ServerSentEvent> item) throws Exception {
8171
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -101,15 +91,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
10191

10292
if (results.isEmpty()) {
10393
upstream().request(1);
104-
} else if (results.size() == 1) {
105-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10694
} else {
107-
// results > 1, but openai spec only wants 1 chunk per SSE event
108-
var firstItem = singleItem(results.poll());
109-
while (results.isEmpty() == false) {
110-
buffer.offer(results.poll());
111-
}
112-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
95+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
11396
}
11497
}
11598

@@ -302,12 +285,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
302285
}
303286
}
304287
}
305-
306-
private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
307-
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
308-
) {
309-
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
310-
deque.offer(result);
311-
return deque;
312-
}
313288
}

0 commit comments

Comments
 (0)