diff --git a/docs/changelog/127658.yaml b/docs/changelog/127658.yaml new file mode 100644 index 0000000000000..1a8d5ced7c8b6 --- /dev/null +++ b/docs/changelog/127658.yaml @@ -0,0 +1,5 @@ +pr: 127658 +summary: Append all data to Chat Completion buffer +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java new file mode 100644 index 0000000000000..060f66566109a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; + +public final class DequeUtils { + + private DequeUtils() { + // util functions only + } + + public static Deque readDeque(StreamInput in, Writeable.Reader reader) throws IOException { + return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in)))); + } + + public static boolean dequeEquals(Deque thisDeque, Deque otherDeque) { + if (thisDeque.size() != otherDeque.size()) { + return false; + } + var thisIter = thisDeque.iterator(); + var otherIter = otherDeque.iterator(); + while (thisIter.hasNext() && otherIter.hasNext()) { + if (thisIter.next().equals(otherIter.next()) == false) { + return false; + } + } + return true; + } + + public static int dequeHashCode(Deque deque) { + if (deque == null) { + return 0; + } + return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); + } + + public static Deque of(T elem) { + var deque = new ArrayDeque(1); + deque.offer(elem); + return deque; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 515c366b5ed13..a9f27b9451487 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.core.inference.DequeUtils; import java.io.IOException; import java.util.Collections; @@ -23,15 +24,15 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Flow; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk; /** * Chat Completion results that only contain a Flow.Publisher. */ -public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) - implements - InferenceServiceResults { +public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) implements InferenceServiceResults { public static final String NAME = "chat_completion_chunk"; public static final String MODEL_FIELD = "model"; @@ -54,6 +55,63 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) { + Deque buffer = new LinkedBlockingDeque<>(); + AtomicBoolean onComplete = new AtomicBoolean(); + this.publisher = downstream -> { + publisher.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + downstream.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + var nextItem = buffer.poll(); + if (nextItem != null) { + downstream.onNext(new Results(DequeUtils.of(nextItem))); + } else if (onComplete.get()) { + downstream.onComplete(); + } else { + subscription.request(n); + } + } + + @Override + public void cancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(Results item) { + var chunks = item.chunks(); + var firstItem = chunks.poll(); + chunks.forEach(buffer::offer); + downstream.onNext(new Results(DequeUtils.of(firstItem))); + } + + @Override + public void onError(Throwable throwable) { + downstream.onError(throwable); + } + + @Override + public void onComplete() { + // only complete if the buffer is empty, so that the client has a chance to drain the buffer + if (onComplete.compareAndSet(false, true)) { + if (buffer.isEmpty()) { + downstream.onComplete(); + } + } + } + }); + }; + } + @Override public boolean isStreaming() { return true; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java index a8f569dbef9d1..7d364d84469ba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -18,6 +18,16 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { @@ -195,4 +205,99 @@ public void testToolCallToXContentChunked() throws IOException { assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); } + public void testBufferedPublishing() { + var results = new ArrayDeque(); + results.offer(randomChatCompletionChunk()); + results.offer(randomChatCompletionChunk()); + var completed = new AtomicBoolean(); + var streamingResults = new StreamingUnifiedChatCompletionResults(downstream -> { + downstream.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (completed.compareAndSet(false, true)) { + downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + downstream.onComplete(); + } + } + + @Override + public void cancel() { + fail("Cancel should never be called."); + } + }); + }); + + AtomicInteger counter = new AtomicInteger(0); + AtomicReference upstream = new AtomicReference<>(null); + Flow.Subscriber subscriber = spy( + new Flow.Subscriber() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + if (upstream.compareAndSet(null, subscription) == false) { + fail("Upstream already set?!"); + } + subscription.request(1); + } + + @Override + public void onNext(StreamingUnifiedChatCompletionResults.Results item) { + assertNotNull(item); + counter.incrementAndGet(); + var sub = upstream.get(); + if (sub != null) { + sub.request(1); + } else { + fail("Upstream not yet set?!"); + } + } + + @Override + public void onError(Throwable throwable) { + fail(throwable); + } + + @Override + public void onComplete() {} + } + ); + streamingResults.publisher().subscribe(subscriber); + verify(subscriber, times(2)).onNext(any()); + } + + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() { + Supplier randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5); + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + randomAlphanumericOfLength(5), + randomBoolean() ? null : randomList(randomInt(5), () -> { + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + randomOptionalString.get(), + randomOptionalString.get(), + randomOptionalString.get(), + randomBoolean() ? null : randomList(randomInt(5), () -> { + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + randomInt(5), + randomOptionalString.get(), + randomBoolean() + ? null + : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + randomOptionalString.get(), + randomOptionalString.get() + ), + randomOptionalString.get() + ); + }) + ), + randomOptionalString.get(), + randomInt(5) + ); + }), + randomAlphanumericOfLength(5), + randomAlphanumericOfLength(5), + randomBoolean() + ? null + : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5)) + ); + } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 9355fa7d0ad48..03006feee0b81 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -32,6 +32,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.DequeUtils; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; @@ -198,21 +199,23 @@ public void cancel() {} "object": "chat.completion.chunk" } */ - private ChunkedToXContent unifiedCompletionChunk(String delta) { - return params -> ChunkedToXContentHelper.chunk( - (b, p) -> b.startObject() - .field("id", "id") - .startArray("choices") - .startObject() - .startObject("delta") - .field("content", delta) - .endObject() - .field("index", 0) - .endObject() - .endArray() - .field("model", "gpt-4o-2024-08-06") - .field("object", "chat.completion.chunk") - .endObject() + private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) { + return new StreamingUnifiedChatCompletionResults.Results( + DequeUtils.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "id", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null), + null, + 0 + ) + ), + "gpt-4o-2024-08-06", + "chat.completion.chunk", + null + ) + ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index bfd4456279a8a..9cb3a47681d21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -9,7 +9,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -28,13 +27,14 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; -import java.util.concurrent.LinkedBlockingDeque; import java.util.function.BiFunction; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { +public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor< + Deque, + StreamingUnifiedChatCompletionResults.Results> { public static final String FUNCTION_FIELD = "function"; private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); @@ -60,22 +60,12 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor errorParser; - private final Deque buffer = new LinkedBlockingDeque<>(); private volatile boolean previousEventWasError = false; public OpenAiUnifiedStreamingProcessor(BiFunction errorParser) { this.errorParser = errorParser; } - @Override - protected void upstreamRequest(long n) { - if (buffer.isEmpty()) { - super.upstreamRequest(n); - } else { - downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); - } - } - @Override protected void next(Deque item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); @@ -101,15 +91,8 @@ protected void next(Deque item) throws Exception { if (results.isEmpty()) { upstream().request(1); - } else if (results.size() == 1) { - downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); } else { - // results > 1, but openai spec only wants 1 chunk per SSE event - var firstItem = singleItem(results.poll()); - while (results.isEmpty() == false) { - buffer.offer(results.poll()); - } - downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); } } @@ -302,12 +285,4 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa } } } - - private Deque singleItem( - StreamingUnifiedChatCompletionResults.ChatCompletionChunk result - ) { - var deque = new ArrayDeque(1); - deque.offer(result); - return deque; - } }