From beb5e38fea15d9f5d176829a2672f789d85b5f0f Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 13 Feb 2025 07:47:39 -0500 Subject: [PATCH 1/3] [ML] Make Streaming Results Writeable Make streaming elements extend Writeable and create StreamInput constructors so we can publish elements across nodes using the transport layer. Additional notes: - Moved optional methods into the InferenceServiceResults interface and default them - StreamingUnifiedChatCompletionResults elements are now all records --- .../inference/InferenceServiceResults.java | 35 ++- .../inference/action/InferenceAction.java | 7 +- .../StreamingChatCompletionResults.java | 101 ++++--- ...StreamingUnifiedChatCompletionResults.java | 256 ++++++++++-------- .../StreamingChatCompletionResultsTests.java | 42 +++ ...mingUnifiedChatCompletionResultsTests.java | 68 ++++- ...stStreamingCompletionServiceExtension.java | 79 ++++-- ...rverSentEventsRestActionListenerTests.java | 28 +- .../InferenceNamedWriteablesProvider.java | 16 ++ .../action/BaseTransportInferenceAction.java | 12 +- ...sportUnifiedCompletionInferenceAction.java | 5 +- .../amazonbedrock/AmazonBedrockClient.java | 5 +- .../AmazonBedrockInferenceClient.java | 5 +- .../openai/OpenAiStreamingProcessor.java | 4 +- .../OpenAiUnifiedStreamingProcessor.java | 5 +- .../AmazonBedrockChatCompletionRequest.java | 6 +- .../BaseTransportInferenceActionTestCase.java | 4 +- .../OpenAiUnifiedStreamingProcessorTests.java | 122 ++++----- 18 files changed, 529 insertions(+), 271 deletions(-) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 34c8ffcb82d09..9b55e9ce54a86 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -10,8 +10,12 @@ package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.xcontent.ToXContent; +import java.io.IOException; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.Flow; @@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte * *

For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.

*/ - List transformToCoordinationFormat(); + default List transformToCoordinationFormat() { + throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented"); + } /** * Transform the result to match the format required for versions prior to * {@link org.elasticsearch.TransportVersions#V_8_12_0} */ - List transformToLegacyFormat(); + default List transformToLegacyFormat() { + throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented"); + } /** * Convert the result to a map to aid with test assertions */ - Map asMap(); + default Map asMap() { + throw new UnsupportedOperationException("asMap() is not implemented"); + } + + default String getWriteableName() { + assert isStreaming() : "This must be implemented when isStreaming() == false"; + throw new UnsupportedOperationException("This must be implemented when isStreaming() == false"); + } + + default void writeTo(StreamOutput out) throws IOException { + assert isStreaming() : "This must be implemented when isStreaming() == false"; + throw new UnsupportedOperationException("This must be implemented when isStreaming() == false"); + } + + default Iterator toXContentChunked(ToXContent.Params params) { + assert isStreaming() : "This must be implemented when isStreaming() == false"; + throw new UnsupportedOperationException("This must be implemented when isStreaming() == false"); + } /** * Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload. @@ -52,8 +77,10 @@ default boolean isStreaming() { * When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher. * Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks. */ - default Flow.Publisher publisher() { + default Flow.Publisher publisher() { assert isStreaming() == false : "This must be implemented when isStreaming() == true"; throw new UnsupportedOperationException("This must be implemented when isStreaming() == true"); } + + interface Result extends NamedWriteable, ChunkedToXContent {} } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index f2b2c563d7519..dc177795af76a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.common.xcontent.ChunkedToXContentObject; import org.elasticsearch.core.TimeValue; @@ -342,7 +341,7 @@ public static class Response extends ActionResponse implements ChunkedToXContent private final InferenceServiceResults results; private final boolean isStreaming; - private final Flow.Publisher publisher; + private final Flow.Publisher publisher; public Response(InferenceServiceResults results) { this.results = results; @@ -350,7 +349,7 @@ public Response(InferenceServiceResults results) { this.publisher = null; } - public Response(InferenceServiceResults results, Flow.Publisher publisher) { + public Response(InferenceServiceResults results, Flow.Publisher publisher) { this.results = results; this.isStreaming = true; this.publisher = publisher; @@ -434,7 +433,7 @@ public boolean isStreaming() { * When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription. * If the RestResponse is closed, it will cancel the subscription. */ - public Flow.Publisher publisher() { + public Flow.Publisher publisher() { assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null"; return publisher; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java index 59778b83953ff..7926590ed5326 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -8,18 +8,19 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ToXContent; import java.io.IOException; +import java.util.ArrayDeque; import java.util.Deque; import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.Objects; import java.util.concurrent.Flow; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; @@ -27,44 +28,26 @@ /** * Chat Completion results that only contain a Flow.Publisher. */ -public record StreamingChatCompletionResults(Flow.Publisher publisher) implements InferenceServiceResults { +public record StreamingChatCompletionResults(Flow.Publisher publisher) + implements + InferenceServiceResults { @Override public boolean isStreaming() { return true; } - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Not implemented"); - } + public record Results(Deque results) implements InferenceServiceResults.Result { + public static final String NAME = "streaming_chat_completion_results"; - @Override - public Map asMap() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public String getWriteableName() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new UnsupportedOperationException("Not implemented"); - } + public Results(StreamInput in) throws IOException { + this(deque(in)); + } - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - throw new UnsupportedOperationException("Not implemented"); - } + private static Deque deque(StreamInput in) throws IOException { + return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(new Result(stream)))); + } - public record Results(Deque results) implements ChunkedToXContent { @Override public Iterator toXContentChunked(ToXContent.Params params) { return Iterators.concat( @@ -75,14 +58,66 @@ public Iterator toXContentChunked(ToXContent.Params params ChunkedToXContentHelper.endObject() ); } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(results); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Results other = (Results) o; + return dequeEquals(this.results, other.results()); + } + + private static boolean dequeEquals(Deque thisDeque, Deque otherDeque) { + if (thisDeque.size() != otherDeque.size()) { + return false; + } + var thisIter = thisDeque.iterator(); + var otherIter = otherDeque.iterator(); + while (thisIter.hasNext() && otherIter.hasNext()) { + if (thisIter.next().equals(otherIter.next()) == false) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + return dequeHashCode(results); + } + + private static int dequeHashCode(Deque deque) { + if (deque == null) { + return 0; + } + return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); + } } - public record Result(String delta) implements ChunkedToXContent { + public record Result(String delta) implements ChunkedToXContent, Writeable { private static final String RESULT = "delta"; + private Result(StreamInput in) throws IOException { + this(in.readString()); + } + @Override public Iterator toXContentChunked(ToXContent.Params params) { return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field(RESULT, delta).endObject()); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(delta); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 515c366b5ed13..9c4cc6f2922b4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.common.xcontent.ChunkedToXContentObject; -import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ToXContent; import java.io.IOException; +import java.util.ArrayDeque; import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.concurrent.Flow; import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk; @@ -29,7 +30,7 @@ /** * Chat Completion results that only contain a Flow.Publisher. */ -public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) +public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) implements InferenceServiceResults { @@ -60,76 +61,83 @@ public boolean isStreaming() { } @Override - public List transformToCoordinationFormat() { + public Iterator toXContentChunked(ToXContent.Params params) { throw new UnsupportedOperationException("Not implemented"); } - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Not implemented"); - } + public record Results(Deque chunks) implements InferenceServiceResults.Result { + public static String NAME = "streaming_unified_chat_completion_results"; - @Override - public Map asMap() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public String getWriteableName() { - throw new UnsupportedOperationException("Not implemented"); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new UnsupportedOperationException("Not implemented"); - } + public Results(StreamInput in) throws IOException { + this(deque(in)); + } - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - throw new UnsupportedOperationException("Not implemented"); - } + private static Deque deque(StreamInput in) throws IOException { + return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(new ChatCompletionChunk(stream)))); + } - public record Results(Deque chunks) implements ChunkedToXContent { @Override public Iterator toXContentChunked(ToXContent.Params params) { return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params))); } - } - public static class ChatCompletionChunk implements ChunkedToXContent { - private final String id; + @Override + public String getWriteableName() { + return NAME; + } - public String getId() { - return id; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(chunks, StreamOutput::writeWriteable); } - public List getChoices() { - return choices; + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Results results = (Results) o; + return dequeEquals(chunks, results.chunks()); } - public String getModel() { - return model; + private static boolean dequeEquals(Deque thisDeque, Deque otherDeque) { + if (thisDeque.size() != otherDeque.size()) { + return false; + } + var thisIter = thisDeque.iterator(); + var otherIter = otherDeque.iterator(); + while (thisIter.hasNext() && otherIter.hasNext()) { + if (thisIter.next().equals(otherIter.next()) == false) { + return false; + } + } + return true; } - public String getObject() { - return object; + @Override + public int hashCode() { + return dequeHashCode(chunks); } - public Usage getUsage() { - return usage; + private static int dequeHashCode(Deque deque) { + if (deque == null) { + return 0; + } + return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); } + } - private final List choices; - private final String model; - private final String object; - private final ChatCompletionChunk.Usage usage; - - public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) { - this.id = id; - this.choices = choices; - this.model = model; - this.object = object; - this.usage = usage; + public record ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) + implements + ChunkedToXContent, + Writeable { + + private ChatCompletionChunk(StreamInput in) throws IOException { + this( + in.readString(), + in.readOptionalCollectionAsList(Choice::new), + in.readString(), + in.readString(), + in.readOptional(Usage::new) + ); } @Override @@ -152,7 +160,23 @@ public Iterator toXContentChunked(ToXContent.Params params ); } - public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) implements ChunkedToXContentObject { + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeOptionalCollection(choices); + out.writeString(model); + out.writeString(object); + out.writeOptionalWriteable(usage); + } + + public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) + implements + ChunkedToXContentObject, + Writeable { + + private Choice(StreamInput in) throws IOException { + this(new Delta(in), in.readOptionalString(), in.readInt()); + } /* choices: Array<{ @@ -172,17 +196,22 @@ public Iterator toXContentChunked(ToXContent.Params params ); } - public static class Delta { - private final String content; - private final String refusal; - private final String role; - private List toolCalls; - - public Delta(String content, String refusal, String role, List toolCalls) { - this.content = content; - this.refusal = refusal; - this.role = role; - this.toolCalls = toolCalls; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeWriteable(delta); + out.writeOptionalString(finishReason); + out.writeInt(index); + } + + public record Delta(String content, String refusal, String role, List toolCalls) implements Writeable { + + private Delta(StreamInput in) throws IOException { + this( + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalCollectionAsList(ToolCall::new) + ); } /* @@ -214,49 +243,26 @@ public Iterator toXContentChunked(ToXContent.Params params } - public String getContent() { - return content; - } - - public String getRefusal() { - return refusal; - } - - public String getRole() { - return role; - } - - public List getToolCalls() { - return toolCalls; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(content); + out.writeOptionalString(refusal); + out.writeOptionalString(role); + out.writeOptionalCollection(toolCalls); } - public static class ToolCall implements ChunkedToXContentObject { - private final int index; - private final String id; - public ChatCompletionChunk.Choice.Delta.ToolCall.Function function; - private final String type; - - public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) { - this.index = index; - this.id = id; - this.function = function; - this.type = type; - } - - public int getIndex() { - return index; - } - - public String getId() { - return id; - } - - public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() { - return function; - } - - public String getType() { - return type; + public record ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) + implements + ChunkedToXContentObject, + Writeable { + + private ToolCall(StreamInput in) throws IOException { + this( + in.readInt(), + in.readOptionalString(), + in.readOptional(ChatCompletionChunk.Choice.Delta.ToolCall.Function::new), + in.readOptionalString() + ); } /* @@ -280,8 +286,8 @@ public Iterator toXContentChunked(ToXContent.Params params content = Iterators.concat( content, ChunkedToXContentHelper.startObject(FUNCTION_FIELD), - optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()), - optionalField(FUNCTION_NAME_FIELD, function.getName()), + optionalField(FUNCTION_ARGUMENTS_FIELD, function.arguments()), + optionalField(FUNCTION_NAME_FIELD, function.name()), ChunkedToXContentHelper.endObject() ); } @@ -294,28 +300,42 @@ public Iterator toXContentChunked(ToXContent.Params params return content; } - public static class Function { - private final String arguments; - private final String name; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(index); + out.writeOptionalString(id); + out.writeOptionalWriteable(function); + out.writeOptionalString(type); + } - public Function(String arguments, String name) { - this.arguments = arguments; - this.name = name; - } + public record Function(String arguments, String name) implements Writeable { - public String getArguments() { - return arguments; + private Function(StreamInput in) throws IOException { + this(in.readOptionalString(), in.readOptionalString()); } - public String getName() { - return name; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(arguments); + out.writeOptionalString(name); } } } } } - public record Usage(int completionTokens, int promptTokens, int totalTokens) {} + public record Usage(int completionTokens, int promptTokens, int totalTokens) implements Writeable { + private Usage(StreamInput in) throws IOException { + this(in.readInt(), in.readInt(), in.readInt()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(completionTokens); + out.writeInt(promptTokens); + out.writeInt(totalTokens); + } + } private static Iterator optionalField(String name, String value) { if (value == null) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java new file mode 100644 index 0000000000000..08eeb49781bad --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.ArrayDeque; + +public class StreamingChatCompletionResultsTests extends AbstractWireSerializingTestCase< + StreamingChatCompletionResults.Results> { + @Override + protected Writeable.Reader instanceReader() { + return StreamingChatCompletionResults.Results::new; + } + + @Override + protected StreamingChatCompletionResults.Results createTestInstance() { + var results = new ArrayDeque(); + for(int i = 0; i < randomIntBetween(1, 10) ; i++) { + results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5))); + } + return new StreamingChatCompletionResults.Results(results); + } + + @Override + protected StreamingChatCompletionResults.Results mutateInstance(StreamingChatCompletionResults.Results instance) throws IOException { + var results = new ArrayDeque<>(instance.results()); + if(randomBoolean()) { + results.pop(); + } else { + results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5))); + } + return new StreamingChatCompletionResults.Results(results); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java index a8f569dbef9d1..669ba2f881fe7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -10,7 +10,8 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; @@ -18,8 +19,10 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.List; +import java.util.function.Supplier; -public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { +public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase< + StreamingUnifiedChatCompletionResults.Results> { public void testResults_toXContentChunked() throws IOException { String expected = """ @@ -195,4 +198,65 @@ public void testToolCallToXContentChunked() throws IOException { assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); } + @Override + protected Writeable.Reader instanceReader() { + return StreamingUnifiedChatCompletionResults.Results::new; + } + + @Override + protected StreamingUnifiedChatCompletionResults.Results createTestInstance() { + var results = new ArrayDeque(); + for (int i = 0; i < randomIntBetween(1, 3); i++) { + results.offer(randomChatCompletionChunk()); + } + return new StreamingUnifiedChatCompletionResults.Results(results); + } + + private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() { + Supplier randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5); + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + randomAlphanumericOfLength(5), + randomBoolean() ? null : randomList(randomInt(5), () -> { + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + randomOptionalString.get(), + randomOptionalString.get(), + randomOptionalString.get(), + randomBoolean() ? null : randomList(randomInt(5), () -> { + return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + randomInt(5), + randomOptionalString.get(), + randomBoolean() + ? null + : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + randomOptionalString.get(), + randomOptionalString.get() + ), + randomOptionalString.get() + ); + }) + ), + randomOptionalString.get(), + randomInt(5) + ); + }), + randomAlphanumericOfLength(5), + randomAlphanumericOfLength(5), + randomBoolean() + ? null + : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5)) + ); + } + + @Override + protected StreamingUnifiedChatCompletionResults.Results mutateInstance(StreamingUnifiedChatCompletionResults.Results instance) + throws IOException { + var results = new ArrayDeque<>(instance.chunks()); + if (randomBoolean()) { + results.pop(); + } else { + results.add(randomChatCompletionChunk()); + } + return new StreamingUnifiedChatCompletionResults.Results(results); // immutable + } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 9355fa7d0ad48..8c876e9947bba 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.LazyInitializable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; @@ -30,6 +29,7 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; @@ -38,6 +38,7 @@ import java.io.IOException; import java.util.EnumSet; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -157,10 +158,31 @@ public void cancel() {} }); } - private ChunkedToXContent completionChunk(String delta) { - return params -> ChunkedToXContentHelper.chunk( - (b, p) -> b.startObject().startArray(COMPLETION).startObject().field("delta", delta).endObject().endArray().endObject() - ); + private InferenceServiceResults.Result completionChunk(String delta) { + return new InferenceServiceResults.Result() { + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return ChunkedToXContentHelper.chunk( + (b, p) -> b.startObject() + .startArray(COMPLETION) + .startObject() + .field("delta", delta) + .endObject() + .endArray() + .endObject() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(delta); + } + + @Override + public String getWriteableName() { + return "test_completionChunk"; + } + }; } private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) { @@ -198,22 +220,37 @@ public void cancel() {} "object": "chat.completion.chunk" } */ - private ChunkedToXContent unifiedCompletionChunk(String delta) { - return params -> ChunkedToXContentHelper.chunk( - (b, p) -> b.startObject() - .field("id", "id") - .startArray("choices") - .startObject() - .startObject("delta") - .field("content", delta) - .endObject() - .field("index", 0) - .endObject() - .endArray() - .field("model", "gpt-4o-2024-08-06") - .field("object", "chat.completion.chunk") - .endObject() - ); + private InferenceServiceResults.Result unifiedCompletionChunk(String delta) { + return new InferenceServiceResults.Result() { + @Override + public String getWriteableName() { + return "test_unifiedCompletionChunk"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(delta); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return ChunkedToXContentHelper.chunk( + (b, p) -> b.startObject() + .field("id", "id") + .startArray("choices") + .startObject() + .startObject("delta") + .field("content", delta) + .endObject() + .field("index", 0) + .endObject() + .endArray() + .field("model", "gpt-4o-2024-08-06") + .field("object", "chat.completion.chunk") + .endObject() + ); + } + }; } @Override diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java index 903961794b337..d9865e0d26337 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java @@ -33,7 +33,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsFilter; import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.inference.InferenceResults; @@ -179,14 +178,14 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c } private static class StreamingInferenceServiceResults implements InferenceServiceResults { - private final Flow.Publisher publisher; + private final Flow.Publisher publisher; - private StreamingInferenceServiceResults(Flow.Publisher publisher) { + private StreamingInferenceServiceResults(Flow.Publisher publisher) { this.publisher = publisher; } @Override - public Flow.Publisher publisher() { + public Flow.Publisher publisher() { return publisher; } @@ -224,7 +223,7 @@ public Iterator toXContentChunked(ToXContent.Params params } } - private static class RandomPublisher implements Flow.Publisher { + private static class RandomPublisher implements Flow.Publisher { private final int requestCount; private final boolean withError; @@ -234,7 +233,7 @@ private RandomPublisher(int requestCount, boolean withError) { } @Override - public void subscribe(Flow.Subscriber subscriber) { + public void subscribe(Flow.Subscriber subscriber) { var resultCount = new AtomicInteger(requestCount); subscriber.onSubscribe(new Flow.Subscription() { @Override @@ -256,12 +255,25 @@ public void cancel() { } } - private static class RandomString implements ChunkedToXContent { + private record RandomString(String randomString) implements InferenceServiceResults.Result { + RandomString() { + this(randomUnicodeOfLengthBetween(2, 20)); + } + @Override public Iterator toXContentChunked(ToXContent.Params params) { - var randomString = randomUnicodeOfLengthBetween(2, 20); return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field("delta", randomString).endObject()); } + + @Override + public String getWriteableName() { + return "test_RandomString"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(randomString); + } } private static class SingleInferenceServiceResults implements InferenceServiceResults { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index e8dc763116707..563247de44f81 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -23,6 +23,8 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; @@ -484,6 +486,20 @@ private static void addInferenceResultsNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 4afafc5adf0c3..2417561cc4497 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -276,7 +275,10 @@ private void inferOnServiceWithMetrics( inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { if (request.isStreaming()) { - var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + var taskProcessor = streamingTaskManager.create( + STREAMING_INFERENCE_TASK_TYPE, + STREAMING_TASK_ACTION + ); inferenceResults.publisher().subscribe(taskProcessor); var instrumentedStream = new PublisherWithMetrics(timer, model); @@ -295,7 +297,7 @@ private void inferOnServiceWithMetrics( })); } - protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) { + protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) { return upstream; } @@ -349,7 +351,7 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio ); } - private class PublisherWithMetrics extends DelegatingProcessor { + private class PublisherWithMetrics extends DelegatingProcessor { private final InferenceTimer timer; private final Model model; @@ -360,7 +362,7 @@ private PublisherWithMetrics(InferenceTimer timer, Model model) { } @Override - protected void next(ChunkedToXContent item) { + protected void next(InferenceServiceResults.Result item) { downstream().onNext(item); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 1144a11d86cc9..4c8f03fae9184 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; @@ -103,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act * as {@link UnifiedChatCompletionException}. */ @Override - protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) { + protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) { return downstream -> { upstream.subscribe(new Flow.Subscriber<>() { @Override @@ -112,7 +111,7 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(ChunkedToXContent item) { + public void onNext(T item) { downstream.onNext(item); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java index f1cfc84643b1c..4369be6bf6209 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java @@ -15,7 +15,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.InferenceServiceResults; import java.time.Instant; import java.util.concurrent.Flow; @@ -23,7 +23,8 @@ public interface AmazonBedrockClient { void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException; - Flow.Publisher converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException; + Flow.Publisher converseStream(ConverseStreamRequest converseStreamRequest) + throws ElasticsearchException; void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) throws ElasticsearchException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java index be90fbbd214d0..c0ee8d4620661 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java @@ -26,10 +26,10 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.SpecialPermission; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; import org.reactivestreams.FlowAdapters; @@ -88,7 +88,8 @@ public void converse(ConverseRequest converseRequest, ActionListener converseStream(ConverseStreamRequest request) throws ElasticsearchException { + public Flow.Publisher converseStream(ConverseStreamRequest request) + throws ElasticsearchException { var awsResponseProcessor = new AmazonBedrockStreamingChatProcessor(threadPool); internalClient.converseStream( request, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 48c8132035b50..fcfd8e19004c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -9,8 +9,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -101,7 +101,7 @@ * * */ -public class OpenAiStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { +public class OpenAiStreamingProcessor extends DelegatingProcessor, InferenceServiceResults.Result> { private static final Logger log = LogManager.getLogger(OpenAiStreamingProcessor.class); private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index bfd4456279a8a..59c17e890e9f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -9,7 +9,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -34,7 +33,9 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { +public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor< + Deque, + StreamingUnifiedChatCompletionResults.Results> { public static final String FUNCTION_FIELD = "function"; private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java index 05d7d90873a71..6c4788fd5142d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java @@ -10,9 +10,9 @@ import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; -import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.common.socket.SocketAccess; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; @@ -82,7 +82,9 @@ public void executeChatCompletionRequest( this.executeRequest(awsBedrockClient); } - public Flow.Publisher executeStreamChatCompletionRequest(AmazonBedrockBaseClient awsBedrockClient) { + public Flow.Publisher executeStreamChatCompletionRequest( + AmazonBedrockBaseClient awsBedrockClient + ) { var converseStreamRequest = ConverseStreamRequest.builder() .modelId(amazonBedrockModel.model()) .messages(getConverseMessageList(requestEntity.messages())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 56966ca40c478..4562b149f3b37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -308,7 +308,7 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(ChunkedToXContent item) { + public void onNext(InferenceServiceResults.Result item) { } @@ -332,7 +332,7 @@ public void onComplete() { })); } - protected Flow.Publisher mockStreamResponse(Consumer> action) { + protected Flow.Publisher mockStreamResponse(Consumer> action) { mockService(true, Set.of(), listener -> { Flow.Processor taskProcessor = mock(); doAnswer(innerAns -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java index 0f127998f9c54..3fc853bac3bbd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -65,31 +65,31 @@ public void testJsonLiteral() { .parse(parser); // Assertions to verify the parsed object - assertEquals("example_id", chunk.getId()); - assertEquals("example_model", chunk.getModel()); - assertEquals("chat.completion.chunk", chunk.getObject()); - assertNotNull(chunk.getUsage()); - assertEquals(50, chunk.getUsage().completionTokens()); - assertEquals(20, chunk.getUsage().promptTokens()); - assertEquals(70, chunk.getUsage().totalTokens()); + assertEquals("example_id", chunk.id()); + assertEquals("example_model", chunk.model()); + assertEquals("chat.completion.chunk", chunk.object()); + assertNotNull(chunk.usage()); + assertEquals(50, chunk.usage().completionTokens()); + assertEquals(20, chunk.usage().promptTokens()); + assertEquals(70, chunk.usage().totalTokens()); - List choices = chunk.getChoices(); + List choices = chunk.choices(); assertEquals(1, choices.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); - assertEquals("example_content", choice.delta().getContent()); - assertNull(choice.delta().getRefusal()); - assertEquals("assistant", choice.delta().getRole()); + assertEquals("example_content", choice.delta().content()); + assertNull(choice.delta().refusal()); + assertEquals("assistant", choice.delta().role()); assertEquals("stop", choice.finishReason()); assertEquals(0, choice.index()); - List toolCalls = choice.delta().getToolCalls(); + List toolCalls = choice.delta().toolCalls(); assertEquals(1, toolCalls.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); - assertEquals(1, toolCall.getIndex()); - assertEquals("tool_call_id", toolCall.getId()); - assertEquals("example_function_name", toolCall.getFunction().getName()); - assertEquals("example_arguments", toolCall.getFunction().getArguments()); - assertEquals("function", toolCall.getType()); + assertEquals(1, toolCall.index()); + assertEquals("tool_call_id", toolCall.id()); + assertEquals("example_function_name", toolCall.function().name()); + assertEquals("example_arguments", toolCall.function().arguments()); + assertEquals("function", toolCall.type()); } catch (IOException e) { fail(); } @@ -143,40 +143,40 @@ public void testJsonLiteralCornerCases() { .parse(parser); // Assertions to verify the parsed object - assertEquals("example_id", chunk.getId()); - assertEquals("example_model", chunk.getModel()); - assertEquals("chat.completion.chunk", chunk.getObject()); - assertNull(chunk.getUsage()); + assertEquals("example_id", chunk.id()); + assertEquals("example_model", chunk.model()); + assertEquals("chat.completion.chunk", chunk.object()); + assertNull(chunk.usage()); - List choices = chunk.getChoices(); + List choices = chunk.choices(); assertEquals(2, choices.size()); // First choice assertions StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); - assertNull(firstChoice.delta().getContent()); - assertNull(firstChoice.delta().getRefusal()); - assertEquals("assistant", firstChoice.delta().getRole()); - assertTrue(firstChoice.delta().getToolCalls().isEmpty()); + assertNull(firstChoice.delta().content()); + assertNull(firstChoice.delta().refusal()); + assertEquals("assistant", firstChoice.delta().role()); + assertTrue(firstChoice.delta().toolCalls().isEmpty()); assertNull(firstChoice.finishReason()); assertEquals(0, firstChoice.index()); // Second choice assertions StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1); - assertEquals("example_content", secondChoice.delta().getContent()); - assertEquals("example_refusal", secondChoice.delta().getRefusal()); - assertEquals("user", secondChoice.delta().getRole()); + assertEquals("example_content", secondChoice.delta().content()); + assertEquals("example_refusal", secondChoice.delta().refusal()); + assertEquals("user", secondChoice.delta().role()); assertEquals("stop", secondChoice.finishReason()); assertEquals(1, secondChoice.index()); List toolCalls = secondChoice.delta() - .getToolCalls(); + .toolCalls(); assertEquals(1, toolCalls.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); - assertEquals(1, toolCall.getIndex()); - assertNull(toolCall.getId()); - assertEquals("example_function_name", toolCall.getFunction().getName()); - assertNull(toolCall.getFunction().getArguments()); - assertEquals("function", toolCall.getType()); + assertEquals(1, toolCall.index()); + assertNull(toolCall.id()); + assertEquals("example_function_name", toolCall.function().name()); + assertNull(toolCall.function().arguments()); + assertEquals("function", toolCall.type()); } catch (IOException e) { fail(); } @@ -221,31 +221,31 @@ public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { .parse(parser); // Assertions to verify the parsed object - assertEquals(chatCompletionChunkId, chunk.getId()); - assertEquals(chatCompletionChunkModel, chunk.getModel()); - assertEquals("chat.completion.chunk", chunk.getObject()); - assertNotNull(chunk.getUsage()); - assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens()); - assertEquals(usagePromptTokens, chunk.getUsage().promptTokens()); - assertEquals(usageTotalTokens, chunk.getUsage().totalTokens()); + assertEquals(chatCompletionChunkId, chunk.id()); + assertEquals(chatCompletionChunkModel, chunk.model()); + assertEquals("chat.completion.chunk", chunk.object()); + assertNotNull(chunk.usage()); + assertEquals(usageCompletionTokens, chunk.usage().completionTokens()); + assertEquals(usagePromptTokens, chunk.usage().promptTokens()); + assertEquals(usageTotalTokens, chunk.usage().totalTokens()); - List choices = chunk.getChoices(); + List choices = chunk.choices(); assertEquals(1, choices.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); - assertEquals(choiceContent, choice.delta().getContent()); - assertNull(choice.delta().getRefusal()); - assertEquals(choiceRole, choice.delta().getRole()); + assertEquals(choiceContent, choice.delta().content()); + assertNull(choice.delta().refusal()); + assertEquals(choiceRole, choice.delta().role()); assertEquals(choiceFinishReason, choice.finishReason()); assertEquals(choiceIndex, choice.index()); - List toolCalls = choice.delta().getToolCalls(); + List toolCalls = choice.delta().toolCalls(); assertEquals(1, toolCalls.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); - assertEquals(toolCallIndex, toolCall.getIndex()); - assertEquals(toolCallId, toolCall.getId()); - assertEquals(toolCallFunctionName, toolCall.getFunction().getName()); - assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments()); - assertEquals(toolCallType, toolCall.getType()); + assertEquals(toolCallIndex, toolCall.index()); + assertEquals(toolCallId, toolCall.id()); + assertEquals(toolCallFunctionName, toolCall.function().name()); + assertEquals(toolCallFunctionArguments, toolCall.function().arguments()); + assertEquals(toolCallType, toolCall.type()); } } @@ -273,20 +273,20 @@ public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IO .parse(parser); // Assertions to verify the parsed object - assertEquals(chatCompletionChunkId, chunk.getId()); - assertEquals(chatCompletionChunkModel, chunk.getModel()); - assertEquals("chat.completion.chunk", chunk.getObject()); - assertNull(chunk.getUsage()); + assertEquals(chatCompletionChunkId, chunk.id()); + assertEquals(chatCompletionChunkModel, chunk.model()); + assertEquals("chat.completion.chunk", chunk.object()); + assertNull(chunk.usage()); - List choices = chunk.getChoices(); + List choices = chunk.choices(); assertEquals(1, choices.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); - assertNull(choice.delta().getContent()); - assertNull(choice.delta().getRefusal()); - assertNull(choice.delta().getRole()); + assertNull(choice.delta().content()); + assertNull(choice.delta().refusal()); + assertNull(choice.delta().role()); assertNull(choice.finishReason()); assertEquals(choiceIndex, choice.index()); - assertTrue(choice.delta().getToolCalls().isEmpty()); + assertTrue(choice.delta().toolCalls().isEmpty()); } } From 66ed99d6552ebf4c7a018355618c7bf3914ae1a8 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 13 Feb 2025 19:22:15 +0000 Subject: [PATCH 2/3] [CI] Auto commit changes from spotless --- .../inference/results/StreamingChatCompletionResults.java | 1 - .../results/StreamingChatCompletionResultsTests.java | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java index 7926590ed5326..db1f1337e6959 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -20,7 +20,6 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.Iterator; -import java.util.Objects; import java.util.concurrent.Flow; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java index 08eeb49781bad..6bdf0ff44ad90 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java @@ -13,8 +13,7 @@ import java.io.IOException; import java.util.ArrayDeque; -public class StreamingChatCompletionResultsTests extends AbstractWireSerializingTestCase< - StreamingChatCompletionResults.Results> { +public class StreamingChatCompletionResultsTests extends AbstractWireSerializingTestCase { @Override protected Writeable.Reader instanceReader() { return StreamingChatCompletionResults.Results::new; @@ -23,7 +22,7 @@ protected Writeable.Reader instanceReade @Override protected StreamingChatCompletionResults.Results createTestInstance() { var results = new ArrayDeque(); - for(int i = 0; i < randomIntBetween(1, 10) ; i++) { + for (int i = 0; i < randomIntBetween(1, 10); i++) { results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5))); } return new StreamingChatCompletionResults.Results(results); @@ -32,7 +31,7 @@ protected StreamingChatCompletionResults.Results createTestInstance() { @Override protected StreamingChatCompletionResults.Results mutateInstance(StreamingChatCompletionResults.Results instance) throws IOException { var results = new ArrayDeque<>(instance.results()); - if(randomBoolean()) { + if (randomBoolean()) { results.pop(); } else { results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5))); From cb028d31e4d9c55b52a7dc3ee038ac5a19ba8120 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Wed, 19 Feb 2025 11:04:23 -0500 Subject: [PATCH 3/3] Adding deque utils helper --- .../xpack/core/inference/DequeUtils.java | 53 +++++++++++++++++++ .../StreamingChatCompletionResults.java | 31 ++--------- ...StreamingUnifiedChatCompletionResults.java | 31 ++--------- .../xpack/core/inference/DequeUtilsTests.java | 52 ++++++++++++++++++ 4 files changed, 113 insertions(+), 54 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java new file mode 100644 index 0000000000000..060f66566109a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; + +public final class DequeUtils { + + private DequeUtils() { + // util functions only + } + + public static Deque readDeque(StreamInput in, Writeable.Reader reader) throws IOException { + return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in)))); + } + + public static boolean dequeEquals(Deque thisDeque, Deque otherDeque) { + if (thisDeque.size() != otherDeque.size()) { + return false; + } + var thisIter = thisDeque.iterator(); + var otherIter = otherDeque.iterator(); + while (thisIter.hasNext() && otherIter.hasNext()) { + if (thisIter.next().equals(otherIter.next()) == false) { + return false; + } + } + return true; + } + + public static int dequeHashCode(Deque deque) { + if (deque == null) { + return 0; + } + return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); + } + + public static Deque of(T elem) { + var deque = new ArrayDeque(1); + deque.offer(elem); + return deque; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java index db1f1337e6959..7657ad498cadf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -17,11 +17,13 @@ import org.elasticsearch.xcontent.ToXContent; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Deque; import java.util.Iterator; import java.util.concurrent.Flow; +import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals; +import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode; +import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; /** @@ -40,11 +42,7 @@ public record Results(Deque results) implements InferenceServiceResults. public static final String NAME = "streaming_chat_completion_results"; public Results(StreamInput in) throws IOException { - this(deque(in)); - } - - private static Deque deque(StreamInput in) throws IOException { - return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(new Result(stream)))); + this(readDeque(in, Result::new)); } @Override @@ -75,31 +73,10 @@ public boolean equals(Object o) { return dequeEquals(this.results, other.results()); } - private static boolean dequeEquals(Deque thisDeque, Deque otherDeque) { - if (thisDeque.size() != otherDeque.size()) { - return false; - } - var thisIter = thisDeque.iterator(); - var otherIter = otherDeque.iterator(); - while (thisIter.hasNext() && otherIter.hasNext()) { - if (thisIter.next().equals(otherIter.next()) == false) { - return false; - } - } - return true; - } - @Override public int hashCode() { return dequeHashCode(results); } - - private static int dequeHashCode(Deque deque) { - if (deque == null) { - return 0; - } - return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); - } } public record Result(String delta) implements ChunkedToXContent, Writeable { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 9c4cc6f2922b4..4604a522c147b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -18,7 +18,6 @@ import org.elasticsearch.xcontent.ToXContent; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.Deque; import java.util.Iterator; @@ -26,6 +25,9 @@ import java.util.concurrent.Flow; import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk; +import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals; +import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode; +import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque; /** * Chat Completion results that only contain a Flow.Publisher. @@ -69,11 +71,7 @@ public record Results(Deque chunks) implements InferenceSer public static String NAME = "streaming_unified_chat_completion_results"; public Results(StreamInput in) throws IOException { - this(deque(in)); - } - - private static Deque deque(StreamInput in) throws IOException { - return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(new ChatCompletionChunk(stream)))); + this(readDeque(in, ChatCompletionChunk::new)); } @Override @@ -98,31 +96,10 @@ public boolean equals(Object o) { return dequeEquals(chunks, results.chunks()); } - private static boolean dequeEquals(Deque thisDeque, Deque otherDeque) { - if (thisDeque.size() != otherDeque.size()) { - return false; - } - var thisIter = thisDeque.iterator(); - var otherIter = otherDeque.iterator(); - while (thisIter.hasNext() && otherIter.hasNext()) { - if (thisIter.next().equals(otherIter.next()) == false) { - return false; - } - } - return true; - } - @Override public int hashCode() { return dequeHashCode(chunks); } - - private static int dequeHashCode(Deque deque) { - if (deque == null) { - return 0; - } - return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); - } } public record ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java new file mode 100644 index 0000000000000..76e71135b531c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.common.io.stream.ByteArrayStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.mockito.Mockito.mock; + +public class DequeUtilsTests extends ESTestCase { + + public void testEqualsAndHashCodeWithSameObject() { + var someObject = mock(); + var dequeOne = DequeUtils.of(someObject); + var dequeTwo = DequeUtils.of(someObject); + assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo)); + assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo)); + } + + public void testEqualsAndHashCodeWithEqualsObject() { + var dequeOne = DequeUtils.of("the same string"); + var dequeTwo = DequeUtils.of("the same string"); + assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo)); + assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo)); + } + + public void testNotEqualsAndHashCode() { + var dequeOne = DequeUtils.of(mock()); + var dequeTwo = DequeUtils.of(mock()); + assertFalse(DequeUtils.dequeEquals(dequeOne, dequeTwo)); + assertNotEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo)); + } + + public void testReadFromStream() throws IOException { + var dequeOne = DequeUtils.of("this is a string"); + var out = new BytesStreamOutput(); + out.writeStringCollection(dequeOne); + var in = new ByteArrayStreamInput(out.bytes().array()); + var dequeTwo = DequeUtils.readDeque(in, StreamInput::readString); + assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo)); + assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo)); + } +}