Skip to content

Commit 6e3db61

Browse files
author
Max Hniebergall
committed
Finalize response format
1 parent 0ba212f commit 6e3db61

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,7 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
8989
public record Results(Deque<ChatCompletionChunk> chunks) implements ChunkedToXContent {
9090
@Override
9191
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
92-
return Iterators.concat(
93-
ChunkedToXContentHelper.startObject(),
94-
ChunkedToXContentHelper.startArray(NAME),
95-
Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)),
96-
ChunkedToXContentHelper.endArray(),
97-
ChunkedToXContentHelper.endObject()
98-
);
92+
return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)));
9993
}
10094
}
10195

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void request(long n) {
5151
if (isClosed.get()) {
5252
downstream.onComplete();
5353
} else if (upstream != null) {
54-
upstream.request(n);
54+
onRequest(n);
5555
} else {
5656
pendingRequests.accumulateAndGet(n, Long::sum);
5757
}
@@ -67,6 +67,13 @@ public void cancel() {
6767
};
6868
}
6969

70+
/**
71+
* Only called when the upstream is set and this processor has not been closed.
72+
*/
73+
protected void onRequest(long n) {
74+
upstream.request(n);
75+
}
76+
7077
protected void onCancel() {}
7178

7279
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.Deque;
2929
import java.util.Iterator;
3030
import java.util.List;
31+
import java.util.concurrent.LinkedBlockingDeque;
3132

3233
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3334
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -58,6 +59,17 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
5859
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
5960
public static final String TOTAL_TOKENS_FIELD = "total_tokens";
6061

62+
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
63+
64+
@Override
65+
protected void onRequest(long n) {
66+
if (buffer.isEmpty()) {
67+
super.onRequest(n);
68+
} else {
69+
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
70+
}
71+
}
72+
6173
@Override
6274
protected void next(Deque<ServerSentEvent> item) throws Exception {
6375
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
@@ -77,8 +89,16 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
7789

7890
if (results.isEmpty()) {
    // Nothing to forward for this item: ask upstream for one more instead of emitting an empty result.
    upstream().request(1);
} else if (results.size() == 1) {
    // The branches must be mutually exclusive. If the empty case fell through to the final
    // else, results.poll() would return null and singleItem(null) would throw a NullPointerException.
    downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
} else {
    // More than one chunk was produced, but the OpenAI spec only wants one chunk per SSE event:
    // emit the first now and buffer the rest for later onRequest calls.
    var firstItem = singleItem(results.poll());
    while (results.isEmpty() == false) {
        buffer.offer(results.poll());
    }
    downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
}
83103
}
84104

@@ -270,4 +290,12 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
270290
}
271291
}
272292
}
293+
294+
private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
295+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
296+
) {
297+
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(2);
298+
deque.offer(result);
299+
return deque;
300+
}
273301
}

0 commit comments

Comments
 (0)