Skip to content

Commit b108e39

Browse files
authored
[ML] Append all data to Chat Completion buffer (elastic#127658)
Moved the Chat Completion buffer into StreamingUnifiedChatCompletionResults so that all Chat Completion responses can benefit from it. Chat Completions is meant to adhere to the OpenAI spec as much as possible, and OpenAI only sends one response chunk at a time, so all implementations of Chat Completions will now buffer. This fixes a bug where, when a single item contained more than two chunks, the extra chunks were dropped; they are now all added to the buffer. It also fixes a bug where onComplete would omit trailing items still in the buffer.
1 parent 542df25 commit b108e39

File tree

5 files changed

+155
-61
lines changed

5 files changed

+155
-61
lines changed

docs/changelog/127658.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127658
2+
summary: Append all data to Chat Completion buffer
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
1717
import org.elasticsearch.inference.InferenceServiceResults;
1818
import org.elasticsearch.xcontent.ToXContent;
19+
import org.elasticsearch.xpack.core.inference.DequeUtils;
1920

2021
import java.io.IOException;
2122
import java.util.Collections;
2223
import java.util.Deque;
2324
import java.util.Iterator;
2425
import java.util.List;
2526
import java.util.concurrent.Flow;
27+
import java.util.concurrent.LinkedBlockingDeque;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2629

2730
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk;
2831
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
@@ -32,9 +35,7 @@
3235
/**
3336
* Chat Completion results that only contain a Flow.Publisher.
3437
*/
35-
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
36-
implements
37-
InferenceServiceResults {
38+
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {
3839

3940
public static final String NAME = "chat_completion_chunk";
4041
public static final String MODEL_FIELD = "model";
@@ -57,6 +58,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Inf
5758
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
5859
public static final String TYPE_FIELD = "type";
5960

61+
/**
 * Wraps {@code publisher} so that every {@code onNext} delivered downstream carries exactly one chunk.
 * <p>
 * The OpenAI spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
 * A buffer is therefore inserted between the upstream data and the downstream client so that only one result is
 * sent at a time: the first chunk of each upstream item is forwarded immediately, the remaining chunks are
 * buffered, and subsequent downstream requests drain the buffer before asking upstream for more.
 *
 * @param publisher the upstream source of results; a single item may carry any number of chunks
 */
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
    Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
    // set once upstream has signalled completion; downstream may still have buffered chunks to drain
    AtomicBoolean onComplete = new AtomicBoolean();
    this.publisher = downstream -> {
        publisher.subscribe(new Flow.Subscriber<Results>() {
            // ensures the terminal onComplete signal reaches downstream at most once
            private final AtomicBoolean downstreamCompleted = new AtomicBoolean();
            // kept so onNext can ask for a replacement when an upstream item carries no chunks
            private volatile Flow.Subscription upstream;

            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                upstream = subscription;
                downstream.onSubscribe(new Flow.Subscription() {
                    @Override
                    public void request(long n) {
                        var nextItem = buffer.poll();
                        if (nextItem != null) {
                            // serve buffered chunks first, one per request
                            downstream.onNext(new Results(DequeUtils.of(nextItem)));
                        } else if (onComplete.get()) {
                            // upstream is done and the buffer is drained
                            completeDownstream();
                        } else {
                            subscription.request(n);
                        }
                    }

                    @Override
                    public void cancel() {
                        subscription.cancel();
                    }
                });
            }

            @Override
            public void onNext(Results item) {
                var chunks = item.chunks();
                var firstItem = chunks.poll();
                if (firstItem == null) {
                    // Nothing to forward: the downstream demand is still unsatisfied, so ask upstream for
                    // another item rather than wrapping a null chunk.
                    upstream.request(1);
                    return;
                }
                chunks.forEach(buffer::offer);
                downstream.onNext(new Results(DequeUtils.of(firstItem)));
            }

            @Override
            public void onError(Throwable throwable) {
                downstream.onError(throwable);
            }

            @Override
            public void onComplete() {
                // only complete if the buffer is empty, so that the client has a chance to drain the buffer
                if (onComplete.compareAndSet(false, true)) {
                    if (buffer.isEmpty()) {
                        completeDownstream();
                    }
                }
            }

            // Signals completion downstream unless it has already been signalled.
            private void completeDownstream() {
                if (downstreamCompleted.compareAndSet(false, true)) {
                    downstream.onComplete();
                }
            }
        });
    };
}
117+
60118
@Override
61119
public boolean isStreaming() {
62120
return true;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,17 @@
1919
import java.util.ArrayDeque;
2020
import java.util.Deque;
2121
import java.util.List;
22+
import java.util.concurrent.Flow;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicInteger;
25+
import java.util.concurrent.atomic.AtomicReference;
2226
import java.util.function.Supplier;
2327

28+
import static org.mockito.ArgumentMatchers.any;
29+
import static org.mockito.Mockito.spy;
30+
import static org.mockito.Mockito.times;
31+
import static org.mockito.Mockito.verify;
32+
2433
public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase<
2534
StreamingUnifiedChatCompletionResults.Results> {
2635

@@ -198,6 +207,66 @@ public void testToolCallToXContentChunked() throws IOException {
198207
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
199208
}
200209

210+
public void testBufferedPublishing() {
211+
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
212+
results.offer(randomChatCompletionChunk());
213+
results.offer(randomChatCompletionChunk());
214+
var completed = new AtomicBoolean();
215+
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
216+
downstream.onSubscribe(new Flow.Subscription() {
217+
@Override
218+
public void request(long n) {
219+
if (completed.compareAndSet(false, true)) {
220+
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
221+
} else {
222+
downstream.onComplete();
223+
}
224+
}
225+
226+
@Override
227+
public void cancel() {
228+
fail("Cancel should never be called.");
229+
}
230+
});
231+
});
232+
233+
AtomicInteger counter = new AtomicInteger(0);
234+
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
235+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = spy(
236+
new Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results>() {
237+
@Override
238+
public void onSubscribe(Flow.Subscription subscription) {
239+
if (upstream.compareAndSet(null, subscription) == false) {
240+
fail("Upstream already set?!");
241+
}
242+
subscription.request(1);
243+
}
244+
245+
@Override
246+
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
247+
assertNotNull(item);
248+
counter.incrementAndGet();
249+
var sub = upstream.get();
250+
if (sub != null) {
251+
sub.request(1);
252+
} else {
253+
fail("Upstream not yet set?!");
254+
}
255+
}
256+
257+
@Override
258+
public void onError(Throwable throwable) {
259+
fail(throwable);
260+
}
261+
262+
@Override
263+
public void onComplete() {}
264+
}
265+
);
266+
streamingResults.publisher().subscribe(subscriber);
267+
verify(subscriber, times(2)).onNext(any());
268+
}
269+
201270
@Override
202271
protected Writeable.Reader<StreamingUnifiedChatCompletionResults.Results> instanceReader() {
203272
return StreamingUnifiedChatCompletionResults.Results::new;

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xcontent.ToXContent;
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xpack.core.inference.DequeUtils;
3738
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3839
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
3940
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
@@ -256,37 +257,24 @@ public void cancel() {}
256257
"object": "chat.completion.chunk"
257258
}
258259
*/
259-
private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
260-
return new InferenceServiceResults.Result() {
261-
@Override
262-
public String getWriteableName() {
263-
return "test_unifiedCompletionChunk";
264-
}
265-
266-
@Override
267-
public void writeTo(StreamOutput out) throws IOException {
268-
out.writeString(delta);
269-
}
270-
271-
@Override
272-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
273-
return ChunkedToXContentHelper.chunk(
274-
(b, p) -> b.startObject()
275-
.field("id", "id")
276-
.startArray("choices")
277-
.startObject()
278-
.startObject("delta")
279-
.field("content", delta)
280-
.endObject()
281-
.field("index", 0)
282-
.endObject()
283-
.endArray()
284-
.field("model", "gpt-4o-2024-08-06")
285-
.field("object", "chat.completion.chunk")
286-
.endObject()
287-
);
288-
}
289-
};
260+
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
261+
return new StreamingUnifiedChatCompletionResults.Results(
262+
DequeUtils.of(
263+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
264+
"id",
265+
List.of(
266+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
267+
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
268+
null,
269+
0
270+
)
271+
),
272+
"gpt-4o-2024-08-06",
273+
"chat.completion.chunk",
274+
null
275+
)
276+
)
277+
);
290278
}
291279

292280
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.Deque;
2727
import java.util.Iterator;
2828
import java.util.List;
29-
import java.util.concurrent.LinkedBlockingDeque;
3029
import java.util.function.BiFunction;
3130

3231
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@@ -60,21 +59,11 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
6059
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
6160

6261
private final BiFunction<String, Exception, Exception> errorParser;
63-
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
6462

6563
public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
6664
this.errorParser = errorParser;
6765
}
6866

69-
@Override
70-
protected void upstreamRequest(long n) {
71-
if (buffer.isEmpty()) {
72-
super.upstreamRequest(n);
73-
} else {
74-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
75-
}
76-
}
77-
7867
@Override
7968
protected void next(Deque<ServerSentEvent> item) throws Exception {
8069
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -96,15 +85,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
9685

9786
if (results.isEmpty()) {
9887
upstream().request(1);
99-
} else if (results.size() == 1) {
100-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10188
} else {
102-
// results > 1, but openai spec only wants 1 chunk per SSE event
103-
var firstItem = singleItem(results.poll());
104-
while (results.isEmpty() == false) {
105-
buffer.offer(results.poll());
106-
}
107-
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
89+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
10890
}
10991
}
11092

@@ -297,12 +279,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
297279
}
298280
}
299281
}
300-
301-
private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
302-
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
303-
) {
304-
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
305-
deque.offer(result);
306-
return deque;
307-
}
308282
}

0 commit comments

Comments
 (0)