Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127658.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127658
summary: Append all data to Chat Completion buffer
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;

public final class DequeUtils {

private DequeUtils() {
// util functions only
}

public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in))));
}

public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
if (thisDeque.size() != otherDeque.size()) {
return false;
}
var thisIter = thisDeque.iterator();
var otherIter = otherDeque.iterator();
while (thisIter.hasNext() && otherIter.hasNext()) {
if (thisIter.next().equals(otherIter.next()) == false) {
return false;
}
}
return true;
}

public static int dequeHashCode(Deque<?> deque) {
if (deque == null) {
return 0;
}
return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
}

public static <T> Deque<T> of(T elem) {
var deque = new ArrayDeque<T>(1);
deque.offer(elem);
return deque;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.DequeUtils;

import java.io.IOException;
import java.util.Collections;
Expand All @@ -23,15 +24,15 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.Flow;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk;

/**
* Chat Completion results that only contain a Flow.Publisher.
*/
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher)
implements
InferenceServiceResults {
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) implements InferenceServiceResults {

public static final String NAME = "chat_completion_chunk";
public static final String MODEL_FIELD = "model";
Expand All @@ -54,6 +55,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
public static final String PROMPT_TOKENS_FIELD = "prompt_tokens";
public static final String TYPE_FIELD = "type";

/**
* OpenAI Spec only returns one result at a time, and Chat Completion adheres to that spec as much as possible.
* So we will insert a buffer in between the upstream data and the downstream client so that we only send one request at a time.
*/
public StreamingUnifiedChatCompletionResults(Flow.Publisher<Results> publisher) {
Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
AtomicBoolean onComplete = new AtomicBoolean();
this.publisher = downstream -> {
publisher.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
downstream.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
var nextItem = buffer.poll();
if (nextItem != null) {
downstream.onNext(new Results(DequeUtils.of(nextItem)));
} else if (onComplete.get()) {
downstream.onComplete();
} else {
subscription.request(n);
}
}

@Override
public void cancel() {
subscription.cancel();
}
});
}

@Override
public void onNext(Results item) {
var chunks = item.chunks();
var firstItem = chunks.poll();
chunks.forEach(buffer::offer);
downstream.onNext(new Results(DequeUtils.of(firstItem)));
}

@Override
public void onError(Throwable throwable) {
downstream.onError(throwable);
}

@Override
public void onComplete() {
// only complete if the buffer is empty, so that the client has a chance to drain the buffer
if (onComplete.compareAndSet(false, true)) {
if (buffer.isEmpty()) {
downstream.onComplete();
}
}
}
});
};
}

@Override
public boolean isStreaming() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {

Expand Down Expand Up @@ -195,4 +205,99 @@ public void testToolCallToXContentChunked() throws IOException {
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
}

public void testBufferedPublishing() {
var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
results.offer(randomChatCompletionChunk());
results.offer(randomChatCompletionChunk());
var completed = new AtomicBoolean();
var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> {
downstream.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
if (completed.compareAndSet(false, true)) {
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results));
} else {
downstream.onComplete();
}
}

@Override
public void cancel() {
fail("Cancel should never be called.");
}
});
});

AtomicInteger counter = new AtomicInteger(0);
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = spy(
new Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
if (upstream.compareAndSet(null, subscription) == false) {
fail("Upstream already set?!");
}
subscription.request(1);
}

@Override
public void onNext(StreamingUnifiedChatCompletionResults.Results item) {
assertNotNull(item);
counter.incrementAndGet();
var sub = upstream.get();
if (sub != null) {
sub.request(1);
} else {
fail("Upstream not yet set?!");
}
}

@Override
public void onError(Throwable throwable) {
fail(throwable);
}

@Override
public void onComplete() {}
}
);
streamingResults.publisher().subscribe(subscriber);
verify(subscriber, times(2)).onNext(any());
}

private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() {
Supplier<String> randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5);
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
randomAlphanumericOfLength(5),
randomBoolean() ? null : randomList(randomInt(5), () -> {
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
randomOptionalString.get(),
randomOptionalString.get(),
randomOptionalString.get(),
randomBoolean() ? null : randomList(randomInt(5), () -> {
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
randomInt(5),
randomOptionalString.get(),
randomBoolean()
? null
: new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
randomOptionalString.get(),
randomOptionalString.get()
),
randomOptionalString.get()
);
})
),
randomOptionalString.get(),
randomInt(5)
);
}),
randomAlphanumericOfLength(5),
randomAlphanumericOfLength(5),
randomBoolean()
? null
: new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5))
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.DequeUtils;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;

Expand Down Expand Up @@ -198,21 +199,23 @@ public void cancel() {}
"object": "chat.completion.chunk"
}
*/
private ChunkedToXContent unifiedCompletionChunk(String delta) {
return params -> ChunkedToXContentHelper.chunk(
(b, p) -> b.startObject()
.field("id", "id")
.startArray("choices")
.startObject()
.startObject("delta")
.field("content", delta)
.endObject()
.field("index", 0)
.endObject()
.endArray()
.field("model", "gpt-4o-2024-08-06")
.field("object", "chat.completion.chunk")
.endObject()
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
return new StreamingUnifiedChatCompletionResults.Results(
DequeUtils.of(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
"id",
List.of(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
null,
0
)
),
"gpt-4o-2024-08-06",
"chat.completion.chunk",
null
)
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -28,13 +27,14 @@
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.BiFunction;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> {
public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
Deque<ServerSentEvent>,
StreamingUnifiedChatCompletionResults.Results> {
public static final String FUNCTION_FIELD = "function";
private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class);

Expand All @@ -60,22 +60,12 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<S
public static final String TOTAL_TOKENS_FIELD = "total_tokens";

private final BiFunction<String, Exception, Exception> errorParser;
private final Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> buffer = new LinkedBlockingDeque<>();
private volatile boolean previousEventWasError = false;

public OpenAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
this.errorParser = errorParser;
}

@Override
protected void upstreamRequest(long n) {
if (buffer.isEmpty()) {
super.upstreamRequest(n);
} else {
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll())));
}
}

@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
Expand All @@ -101,15 +91,8 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {

if (results.isEmpty()) {
upstream().request(1);
} else if (results.size() == 1) {
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
} else {
// results > 1, but openai spec only wants 1 chunk per SSE event
var firstItem = singleItem(results.poll());
while (results.isEmpty() == false) {
buffer.offer(results.poll());
}
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem));
downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
}
}

Expand Down Expand Up @@ -302,12 +285,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa
}
}
}

private Deque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> singleItem(
StreamingUnifiedChatCompletionResults.ChatCompletionChunk result
) {
var deque = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(1);
deque.offer(result);
return deque;
}
}