Skip to content

Commit b82430d

Browse files
authored
[ML] Append all data to Chat Completion buffer (#127658) (#128134) (#128164)
Moved the Chat Completion buffer into the StreamingUnifiedChatCompletionResults so that all Chat Completion responses can benefit from it. Chat Completions is meant to adhere to OpenAI as much as possible, and OpenAI only sends one response chunk at a time. All implementations of Chat Completions will now buffer. This fixes a bug where more than two chunks in a single item would be dropped, instead they are all added to the buffer. This fixes a bug where onComplete would omit trailing items in the buffer.
1 parent b26a277 commit b82430d

File tree

6 files changed

+246
-47
lines changed

6 files changed

+246
-47
lines changed

docs/changelog/127658.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127658
2+
summary: Append all data to Chat Completion buffer
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.Writeable;
12+
13+
import java.io.IOException;
14+
import java.util.ArrayDeque;
15+
import java.util.Deque;
16+
17+
public final class DequeUtils {
18+
19+
private DequeUtils() {
20+
// util functions only
21+
}
22+
23+
public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
24+
return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in))));
25+
}
26+
27+
public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
28+
if (thisDeque.size() != otherDeque.size()) {
29+
return false;
30+
}
31+
var thisIter = thisDeque.iterator();
32+
var otherIter = otherDeque.iterator();
33+
while (thisIter.hasNext() && otherIter.hasNext()) {
34+
if (thisIter.next().equals(otherIter.next()) == false) {
35+
return false;
36+
}
37+
}
38+
return true;
39+
}
40+
41+
public static int dequeHashCode(Deque<?> deque) {
42+
if (deque == null) {
43+
return 0;
44+
}
45+
return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
46+
}
47+
48+
public static <T> Deque<T> of(T elem) {
49+
var deque = new ArrayDeque<T>(1);
50+
deque.offer(elem);
51+
return deque;
52+
}
53+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.inference.InferenceResults;
1515
import org.elasticsearch.inference.InferenceServiceResults;
1616
import org.elasticsearch.xcontent.ToXContent;
17+
import org.elasticsearch.xpack.core.inference.DequeUtils;
1718

1819
import java.io.IOException;
1920
import java.util.Collections;
@@ -22,13 +23,13 @@
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.concurrent.Flow;
26+
import java.util.concurrent.LinkedBlockingDeque;
27+
import java.util.concurrent.atomic.AtomicBoolean;
2528

2629
/**
2730
* Chat Completion results that only contain a Flow.Publisher.
2831
*/
29-
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher)
30-
implements
31-
InferenceServiceResults {
32+
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {
3233

3334
public static final String NAME = "chat_completion_chunk";
3435
public static final String MODEL_FIELD = "model";
@@ -51,6 +52,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
5152
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
5253
public static final String TYPE_FIELD = "type";
5354

55+
/**
56+
* OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
57+
* So we will insert a buffer in between the upstream data and the downstream client so that we only send one request at a time.
58+
*/
59+
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
60+
Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
61+
AtomicBoolean onComplete = new AtomicBoolean();
62+
this.publisher = downstream -> {
63+
publisher.subscribe(new Flow.Subscriber<>() {
64+
@Override
65+
public void onSubscribe(Flow.Subscription subscription) {
66+
downstream.onSubscribe(new Flow.Subscription() {
67+
@Override
68+
public void request(long n) {
69+
var nextItem = buffer.poll();
70+
if (nextItem != null) {
71+
downstream.onNext(new Results(DequeUtils.of(nextItem)));
72+
} else if (onComplete.get()) {
73+
downstream.onComplete();
74+
} else {
75+
subscription.request(n);
76+
}
77+
}
78+
79+
@Override
80+
public void cancel() {
81+
subscription.cancel();
82+
}
83+
});
84+
}
85+
86+
@Override
87+
public void onNext(Results item) {
88+
var chunks = item.chunks();
89+
var firstItem = chunks.poll();
90+
chunks.forEach(buffer::offer);
91+
downstream.onNext(new Results(DequeUtils.of(firstItem)));
92+
}
93+
94+
@Override
95+
public void onError(Throwable throwable) {
96+
downstream.onError(throwable);
97+
}
98+
99+
@Override
100+
public void onComplete() {
101+
// only complete if the buffer is empty, so that the client has a chance to drain the buffer
102+
if (onComplete.compareAndSet(false, true)) {
103+
if (buffer.isEmpty()) {
104+
downstream.onComplete();
105+
}
106+
}
107+
}
108+
});
109+
};
110+
}
111+
54112
@Override
55113
public boolean isStreaming() {
56114
return true;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818
import java.util.ArrayDeque;
1919
import java.util.Deque;
2020
import java.util.List;
21+
import java.util.concurrent.Flow;
22+
import java.util.concurrent.atomic.AtomicBoolean;
23+
import java.util.concurrent.atomic.AtomicInteger;
24+
import java.util.concurrent.atomic.AtomicReference;
25+
import java.util.function.Supplier;
26+
27+
import static org.mockito.ArgumentMatchers.any;
28+
import static org.mockito.Mockito.spy;
29+
import static org.mockito.Mockito.times;
30+
import static org.mockito.Mockito.verify;
2131

2232
public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {
2333

@@ -195,4 +205,99 @@ public void testToolCallToXContentChunked() throws IOException {
195205
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
196206
}
197207

208+
public void testBufferedPublishing() {
209+
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
210+
results.offer(randomChatCompletionChunk());
211+
results.offer(randomChatCompletionChunk());
212+
var completed = new AtomicBoolean();
213+
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
214+
downstream.onSubscribe(new Flow.Subscription() {
215+
@Override
216+
public void request(long n) {
217+
if (completed.compareAndSet(false, true)) {
218+
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
219+
} else {
220+
downstream.onComplete();
221+
}
222+
}
223+
224+
@Override
225+
public void cancel() {
226+
fail("Cancel should never be called.");
227+
}
228+
});
229+
});
230+
231+
AtomicInteger counter = new AtomicInteger(0);
232+
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
233+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = spy(
234+
new Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results>() {
235+
@Override
236+
public void onSubscribe(Flow.Subscription subscription) {
237+
if (upstream.compareAndSet(null, subscription) == false) {
238+
fail("Upstream already set?!");
239+
}
240+
subscription.request(1);
241+
}
242+
243+
@Override
244+
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
245+
assertNotNull(item);
246+
counter.incrementAndGet();
247+
var sub = upstream.get();
248+
if (sub != null) {
249+
sub.request(1);
250+
} else {
251+
fail("Upstream not yet set?!");
252+
}
253+
}
254+
255+
@Override
256+
public void onError(Throwable throwable) {
257+
fail(throwable);
258+
}
259+
260+
@Override
261+
public void onComplete() {}
262+
}
263+
);
264+
streamingResults.publisher().subscribe(subscriber);
265+
verify(subscriber, times(2)).onNext(any());
266+
}
267+
268+
private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() {
269+
Supplier<String> randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5);
270+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
271+
randomAlphanumericOfLength(5),
272+
randomBoolean() ? null : randomList(randomInt(5), () -> {
273+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
274+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
275+
randomOptionalString.get(),
276+
randomOptionalString.get(),
277+
randomOptionalString.get(),
278+
randomBoolean() ? null : randomList(randomInt(5), () -> {
279+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
280+
randomInt(5),
281+
randomOptionalString.get(),
282+
randomBoolean()
283+
? null
284+
: new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
285+
randomOptionalString.get(),
286+
randomOptionalString.get()
287+
),
288+
randomOptionalString.get()
289+
);
290+
})
291+
),
292+
randomOptionalString.get(),
293+
randomInt(5)
294+
);
295+
}),
296+
randomAlphanumericOfLength(5),
297+
randomAlphanumericOfLength(5),
298+
randomBoolean()
299+
? null
300+
: new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5))
301+
);
302+
}
198303
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.rest.RestStatus;
3434
import org.elasticsearch.xcontent.ToXContentObject;
3535
import org.elasticsearch.xcontent.XContentBuilder;
36+
import org.elasticsearch.xpack.core.inference.DequeUtils;
3637
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3738
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
3839

@@ -205,21 +206,23 @@ public void cancel() {}
205206
"object": "chat.completion.chunk"
206207
}
207208
*/
208-
private ChunkedToXContent unifiedCompletionChunk(String delta) {
209-
return params -> Iterators.concat(
210-
ChunkedToXContentHelper.startObject(),
211-
ChunkedToXContentHelper.field("id", "id"),
212-
ChunkedToXContentHelper.startArray("choices"),
213-
ChunkedToXContentHelper.startObject(),
214-
ChunkedToXContentHelper.startObject("delta"),
215-
ChunkedToXContentHelper.field("content", delta),
216-
ChunkedToXContentHelper.endObject(),
217-
ChunkedToXContentHelper.field("index", 0),
218-
ChunkedToXContentHelper.endObject(),
219-
ChunkedToXContentHelper.endArray(),
220-
ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"),
221-
ChunkedToXContentHelper.field("object", "chat.completion.chunk"),
222-
ChunkedToXContentHelper.endObject()
209+
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
210+
return new StreamingUnifiedChatCompletionResults.Results(
211+
DequeUtils.of(
212+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
213+
"id",
214+
List.of(
215+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
216+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
217+
null,
218+
0
219+
)
220+
),
221+
"gpt-4o-2024-08-06",
222+
"chat.completion.chunk",
223+
null
224+
)
225+
)
223226
);
224227
}
225228

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1312
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1413
import org.elasticsearch.xcontent.ConstructingObjectParser;
1514
import org.elasticsearch.xcontent.ParseField;
@@ -28,13 +27,14 @@
2827
import java.util.Deque;
2928
import java.util.Iterator;
3029
import java.util.List;
31-
import java.util.concurrent.LinkedBlockingDeque;
3230
import java.util.function.BiFunction;
3331

3432
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3533
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
3634

37-
public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> {
35+
public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
36+
Deque<ServerSentEvent>,
37+
StreamingUnifiedChatCompletionResults.Results> {
3838
public static final String FUNCTION_FIELD = "function";
3939
private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class);
4040

@@ -60,22 +60,12 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
6060
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
6161

6262
private final BiFunction<String, Exception, Exception> errorParser;
63-
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
6463
private volatile boolean previousEventWasError = false;
6564

6665
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
6766
this.errorParser = errorParser;
6867
}
6968

70-
@Override
71-
protected void upstreamRequest(long n) {
72-
if (buffer.isEmpty()) {
73-
super.upstreamRequest(n);
74-
} else {
75-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
76-
}
77-
}
78-
7969
@Override
8070
protected void next(Deque<ServerSentEvent> item) throws Exception {
8171
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -101,15 +91,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
10191

10292
if (results.isEmpty()) {
10393
upstream().request(1);
104-
} else if (results.size() == 1) {
105-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10694
} else {
107-
// results > 1, but openai spec only wants 1 chunk per SSE event
108-
var firstItem = singleItem(results.poll());
109-
while (results.isEmpty() == false) {
110-
buffer.offer(results.poll());
111-
}
112-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
95+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
11396
}
11497
}
11598

@@ -302,12 +285,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
302285
}
303286
}
304287
}
305-
306-
private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
307-
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
308-
) {
309-
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
310-
deque.offer(result);
311-
return deque;
312-
}
313288
}

0 commit comments

Comments
 (0)