diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
index 34c8ffcb82d09..9b55e9ce54a86 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
@@ -10,8 +10,12 @@
package org.elasticsearch.inference;
import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.xcontent.ToXContent;
+import java.io.IOException;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Flow;
@@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
*
*
For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.
*/
- List extends InferenceResults> transformToCoordinationFormat();
+ default List extends InferenceResults> transformToCoordinationFormat() {
+ throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
+ }
/**
* Transform the result to match the format required for versions prior to
* {@link org.elasticsearch.TransportVersions#V_8_12_0}
*/
- List extends InferenceResults> transformToLegacyFormat();
+ default List extends InferenceResults> transformToLegacyFormat() {
+ throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
+ }
/**
* Convert the result to a map to aid with test assertions
*/
- Map asMap();
+ default Map asMap() {
+ throw new UnsupportedOperationException("asMap() is not implemented");
+ }
+
+ default String getWriteableName() {
+ assert isStreaming() : "This must be implemented when isStreaming() == false";
+ throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+ }
+
+ default void writeTo(StreamOutput out) throws IOException {
+ assert isStreaming() : "This must be implemented when isStreaming() == false";
+ throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+ }
+
+ default Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
+ assert isStreaming() : "This must be implemented when isStreaming() == false";
+ throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+ }
/**
* Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
@@ -52,8 +77,10 @@ default boolean isStreaming() {
* When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
* Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
*/
- default Flow.Publisher extends ChunkedToXContent> publisher() {
+ default Flow.Publisher extends Result> publisher() {
assert isStreaming() == false : "This must be implemented when isStreaming() == true";
throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
}
+
+ interface Result extends NamedWriteable, ChunkedToXContent {}
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java
new file mode 100644
index 0000000000000..060f66566109a
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.Objects;
+
+/**
+ * Static helpers for creating, reading, comparing and hashing {@link Deque}s.
+ */
+public final class DequeUtils {
+
+    private DequeUtils() {
+        // util functions only
+    }
+
+    /**
+     * Reads a collection from {@code in} into a new {@link ArrayDeque} via {@code StreamInput#readCollection}.
+     *
+     * @param in     stream to read from
+     * @param reader reads a single element from the stream
+     * @return the deque that the elements were offered to
+     * @throws IOException if reading from the stream fails
+     */
+    public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
+        // Read from the stream handed to the callback rather than capturing the enclosing parameter.
+        return in.readCollection(ArrayDeque::new, (stream, deque) -> deque.offer(reader.read(stream)));
+    }
+
+    /**
+     * Compares two deques element by element in iteration order.
+     *
+     * @return {@code true} when both arguments are the same reference (including both {@code null}), or when they
+     *         have the same size and every pair of elements is equal; {@code false} otherwise
+     */
+    public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
+        if (thisDeque == otherDeque) {
+            return true;
+        }
+        if (thisDeque == null || otherDeque == null) {
+            return false;
+        }
+        if (thisDeque.size() != otherDeque.size()) {
+            return false;
+        }
+        var thisIter = thisDeque.iterator();
+        var otherIter = otherDeque.iterator();
+        while (thisIter.hasNext() && otherIter.hasNext()) {
+            // Objects.equals tolerates null elements, matching the null handling in dequeHashCode.
+            if (Objects.equals(thisIter.next(), otherIter.next()) == false) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Computes an order-sensitive hash: starting at 1, each element folds in as {@code 31 * hash + elementHash}.
+     *
+     * @return 0 for a {@code null} deque, otherwise the folded hash; {@code null} elements contribute 0
+     */
+    public static int dequeHashCode(Deque<?> deque) {
+        if (deque == null) {
+            return 0;
+        }
+        int hashCode = 1;
+        for (Object element : deque) {
+            hashCode = 31 * hashCode + (element == null ? 0 : element.hashCode());
+        }
+        return hashCode;
+    }
+
+    /**
+     * Creates a new {@link ArrayDeque} containing only {@code elem}.
+     */
+    public static <T> Deque<T> of(T elem) {
+        var deque = new ArrayDeque<T>(1);
+        deque.offer(elem);
+        return deque;
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
index f2b2c563d7519..dc177795af76a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -16,7 +16,6 @@
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.core.TimeValue;
@@ -342,7 +341,7 @@ public static class Response extends ActionResponse implements ChunkedToXContent
private final InferenceServiceResults results;
private final boolean isStreaming;
- private final Flow.Publisher publisher;
+ private final Flow.Publisher publisher;
public Response(InferenceServiceResults results) {
this.results = results;
@@ -350,7 +349,7 @@ public Response(InferenceServiceResults results) {
this.publisher = null;
}
- public Response(InferenceServiceResults results, Flow.Publisher publisher) {
+ public Response(InferenceServiceResults results, Flow.Publisher publisher) {
this.results = results;
this.isStreaming = true;
this.publisher = publisher;
@@ -434,7 +433,7 @@ public boolean isStreaming() {
* When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
* If the RestResponse is closed, it will cancel the subscription.
*/
- public Flow.Publisher publisher() {
+ public Flow.Publisher publisher() {
assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
return publisher;
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java
index 05a181d3fc5b6..21455dc3ff7a5 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java
@@ -8,63 +8,43 @@
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import java.io.IOException;
import java.util.Deque;
import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
import java.util.concurrent.Flow;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
/**
* Chat Completion results that only contain a Flow.Publisher.
*/
-public record StreamingChatCompletionResults(Flow.Publisher extends ChunkedToXContent> publisher) implements InferenceServiceResults {
+public record StreamingChatCompletionResults(Flow.Publisher extends InferenceServiceResults.Result> publisher)
+ implements
+ InferenceServiceResults {
@Override
public boolean isStreaming() {
return true;
}
- @Override
- public List extends InferenceResults> transformToCoordinationFormat() {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public List extends InferenceResults> transformToLegacyFormat() {
- throw new UnsupportedOperationException("Not implemented");
- }
+ public record Results(Deque results) implements InferenceServiceResults.Result {
+ public static final String NAME = "streaming_chat_completion_results";
- @Override
- public Map asMap() {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public String getWriteableName() {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
- throw new UnsupportedOperationException("Not implemented");
- }
+ public Results(StreamInput in) throws IOException {
+ this(readDeque(in, Result::new));
+ }
- public record Results(Deque results) implements ChunkedToXContent {
@Override
public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
return Iterators.concat(
@@ -75,11 +55,37 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
ChunkedToXContentHelper.endObject()
);
}
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeCollection(results);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ Results other = (Results) o;
+ return dequeEquals(this.results, other.results());
+ }
+
+ @Override
+ public int hashCode() {
+ return dequeHashCode(results);
+ }
}
- public record Result(String delta) implements ChunkedToXContent {
+ public record Result(String delta) implements ChunkedToXContent, Writeable {
private static final String RESULT = "delta";
+ private Result(StreamInput in) throws IOException {
+ this(in.readString());
+ }
+
@Override
public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
return Iterators.concat(
@@ -88,5 +94,10 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
ChunkedToXContentHelper.endObject()
);
}
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(delta);
+ }
}
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java
index 90038c67036c4..d502b012bcdcb 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java
@@ -8,10 +8,12 @@
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
@@ -20,13 +22,16 @@
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
-import java.util.Map;
import java.util.concurrent.Flow;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque;
+
/**
* Chat Completion results that only contain a Flow.Publisher.
*/
-public record StreamingUnifiedChatCompletionResults(Flow.Publisher extends ChunkedToXContent> publisher)
+public record StreamingUnifiedChatCompletionResults(Flow.Publisher extends InferenceServiceResults.Result> publisher)
implements
InferenceServiceResults {
@@ -57,76 +62,58 @@ public boolean isStreaming() {
}
@Override
- public List extends InferenceResults> transformToCoordinationFormat() {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public List extends InferenceResults> transformToLegacyFormat() {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public Map asMap() {
- throw new UnsupportedOperationException("Not implemented");
- }
-
- @Override
- public String getWriteableName() {
+ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
throw new UnsupportedOperationException("Not implemented");
}
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- throw new UnsupportedOperationException("Not implemented");
- }
+ public record Results(Deque chunks) implements InferenceServiceResults.Result {
+ public static final String NAME = "streaming_unified_chat_completion_results"; // returned by getWriteableName(); final so it cannot be reassigned
- @Override
- public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
- throw new UnsupportedOperationException("Not implemented");
- }
+ public Results(StreamInput in) throws IOException {
+ this(readDeque(in, ChatCompletionChunk::new));
+ }
- public record Results(Deque chunks) implements ChunkedToXContent {
@Override
public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params) {
return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)));
}
- }
-
- public static class ChatCompletionChunk implements ChunkedToXContent {
- private final String id;
-
- public String getId() {
- return id;
- }
- public List getChoices() {
- return choices;
+ @Override
+ public String getWriteableName() {
+ return NAME;
}
- public String getModel() {
- return model;
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeCollection(chunks, StreamOutput::writeWriteable);
}
- public String getObject() {
- return object;
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ Results results = (Results) o;
+ return dequeEquals(chunks, results.chunks());
}
- public Usage getUsage() {
- return usage;
+ @Override
+ public int hashCode() {
+ return dequeHashCode(chunks);
}
+ }
- private final List choices;
- private final String model;
- private final String object;
- private final ChatCompletionChunk.Usage usage;
-
- public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) {
- this.id = id;
- this.choices = choices;
- this.model = model;
- this.object = object;
- this.usage = usage;
+ public record ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage)
+ implements
+ ChunkedToXContent,
+ Writeable {
+
+ private ChatCompletionChunk(StreamInput in) throws IOException {
+ this(
+ in.readString(),
+ in.readOptionalCollectionAsList(Choice::new),
+ in.readString(),
+ in.readString(),
+ in.readOptional(Usage::new)
+ );
}
@Override
@@ -163,7 +150,23 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
);
}
- public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) {
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(id);
+ out.writeOptionalCollection(choices);
+ out.writeString(model);
+ out.writeString(object);
+ out.writeOptionalWriteable(usage);
+ }
+
+ public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index)
+ implements
+ ChunkedToXContentObject,
+ Writeable {
+
+ private Choice(StreamInput in) throws IOException {
+ this(new Delta(in), in.readOptionalString(), in.readInt());
+ }
/*
choices: Array<{
@@ -182,17 +185,22 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
);
}
- public static class Delta {
- private final String content;
- private final String refusal;
- private final String role;
- private List toolCalls;
-
- public Delta(String content, String refusal, String role, List toolCalls) {
- this.content = content;
- this.refusal = refusal;
- this.role = role;
- this.toolCalls = toolCalls;
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeWriteable(delta);
+ out.writeOptionalString(finishReason);
+ out.writeInt(index);
+ }
+
+ public record Delta(String content, String refusal, String role, List toolCalls) implements Writeable {
+
+ private Delta(StreamInput in) throws IOException {
+ this(
+ in.readOptionalString(),
+ in.readOptionalString(),
+ in.readOptionalString(),
+ in.readOptionalCollectionAsList(ToolCall::new)
+ );
}
/*
@@ -224,49 +232,26 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
}
- public String getContent() {
- return content;
- }
-
- public String getRefusal() {
- return refusal;
- }
-
- public String getRole() {
- return role;
- }
-
- public List getToolCalls() {
- return toolCalls;
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeOptionalString(content);
+ out.writeOptionalString(refusal);
+ out.writeOptionalString(role);
+ out.writeOptionalCollection(toolCalls);
}
- public static class ToolCall {
- private final int index;
- private final String id;
- public ChatCompletionChunk.Choice.Delta.ToolCall.Function function;
- private final String type;
-
- public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) {
- this.index = index;
- this.id = id;
- this.function = function;
- this.type = type;
- }
-
- public int getIndex() {
- return index;
- }
-
- public String getId() {
- return id;
- }
-
- public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() {
- return function;
- }
-
- public String getType() {
- return type;
+ public record ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type)
+ implements
+ ChunkedToXContentObject,
+ Writeable {
+
+ private ToolCall(StreamInput in) throws IOException {
+ this(
+ in.readInt(),
+ in.readOptionalString(),
+ in.readOptional(ChatCompletionChunk.Choice.Delta.ToolCall.Function::new),
+ in.readOptionalString()
+ );
}
/*
@@ -289,8 +274,8 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
content = Iterators.concat(
content,
ChunkedToXContentHelper.startObject(FUNCTION_FIELD),
- ChunkedToXContentHelper.optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()),
- ChunkedToXContentHelper.optionalField(FUNCTION_NAME_FIELD, function.getName()),
+ ChunkedToXContentHelper.optionalField(FUNCTION_ARGUMENTS_FIELD, function.arguments()),
+ ChunkedToXContentHelper.optionalField(FUNCTION_NAME_FIELD, function.name()),
ChunkedToXContentHelper.endObject()
);
}
@@ -303,27 +288,42 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
return content;
}
- public static class Function {
- private final String arguments;
- private final String name;
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeInt(index);
+ out.writeOptionalString(id);
+ out.writeOptionalWriteable(function);
+ out.writeOptionalString(type);
+ }
- public Function(String arguments, String name) {
- this.arguments = arguments;
- this.name = name;
- }
+ public record Function(String arguments, String name) implements Writeable {
- public String getArguments() {
- return arguments;
+ private Function(StreamInput in) throws IOException {
+ this(in.readOptionalString(), in.readOptionalString());
}
- public String getName() {
- return name;
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeOptionalString(arguments);
+ out.writeOptionalString(name);
}
}
}
}
}
- public record Usage(int completionTokens, int promptTokens, int totalTokens) {}
+ public record Usage(int completionTokens, int promptTokens, int totalTokens) implements Writeable {
+ private Usage(StreamInput in) throws IOException {
+ this(in.readInt(), in.readInt(), in.readInt());
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeInt(completionTokens);
+ out.writeInt(promptTokens);
+ out.writeInt(totalTokens);
+ }
+ }
+
}
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java
new file mode 100644
index 0000000000000..76e71135b531c
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference;
+
+import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+import static org.mockito.Mockito.mock;
+
+public class DequeUtilsTests extends ESTestCase { // exercises DequeUtils.of, dequeEquals, dequeHashCode and readDeque
+
+ public void testEqualsAndHashCodeWithSameObject() { // both deques wrap the very same mock instance
+ var someObject = mock();
+ var dequeOne = DequeUtils.of(someObject);
+ var dequeTwo = DequeUtils.of(someObject);
+ assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+ assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+ }
+
+ public void testEqualsAndHashCodeWithEqualsObject() { // each deque gets its own string literal with identical content
+ var dequeOne = DequeUtils.of("the same string");
+ var dequeTwo = DequeUtils.of("the same string");
+ assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+ assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+ }
+
+ public void testNotEqualsAndHashCode() { // two separate mocks; NOTE(review): presumably relies on mocks keeping identity equals/hashCode - confirm
+ var dequeOne = DequeUtils.of(mock());
+ var dequeTwo = DequeUtils.of(mock());
+ assertFalse(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+ assertNotEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+ }
+
+ public void testReadFromStream() throws IOException { // writes a string collection, reads it back with readDeque, expects an equal deque
+ var dequeOne = DequeUtils.of("this is a string");
+ var out = new BytesStreamOutput();
+ out.writeStringCollection(dequeOne);
+ var in = new ByteArrayStreamInput(out.bytes().array());
+ var dequeTwo = DequeUtils.readDeque(in, StreamInput::readString);
+ assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+ assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java
new file mode 100644
index 0000000000000..6bdf0ff44ad90
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+
+/**
+ * Tests for {@link StreamingChatCompletionResults.Results} built on {@link AbstractWireSerializingTestCase}.
+ */
+public class StreamingChatCompletionResultsTests extends AbstractWireSerializingTestCase<StreamingChatCompletionResults.Results> {
+    @Override
+    protected Writeable.Reader<StreamingChatCompletionResults.Results> instanceReader() {
+        return StreamingChatCompletionResults.Results::new;
+    }
+
+    @Override
+    protected StreamingChatCompletionResults.Results createTestInstance() {
+        var results = new ArrayDeque<StreamingChatCompletionResults.Result>();
+        // Draw the size once; a random bound in the loop condition would be re-rolled on every iteration.
+        int size = randomIntBetween(1, 10);
+        for (int i = 0; i < size; i++) {
+            results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5)));
+        }
+        return new StreamingChatCompletionResults.Results(results);
+    }
+
+    @Override
+    protected StreamingChatCompletionResults.Results mutateInstance(StreamingChatCompletionResults.Results instance) throws IOException {
+        var results = new ArrayDeque<>(instance.results());
+        // Either branch changes the number of elements by one, so the copy differs from the input.
+        if (randomBoolean()) {
+            results.pop();
+        } else {
+            results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5)));
+        }
+        return new StreamingChatCompletionResults.Results(results);
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java
index a8f569dbef9d1..669ba2f881fe7 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java
@@ -10,7 +10,8 @@
package org.elasticsearch.xpack.core.inference.results;
import org.elasticsearch.common.Strings;
-import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
@@ -18,8 +19,10 @@
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
+import java.util.function.Supplier;
-public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {
+public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase<
+ StreamingUnifiedChatCompletionResults.Results> {
public void testResults_toXContentChunked() throws IOException {
String expected = """
@@ -195,4 +198,65 @@ public void testToolCallToXContentChunked() throws IOException {
assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
}
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return StreamingUnifiedChatCompletionResults.Results::new;
+ }
+
+ @Override
+ protected StreamingUnifiedChatCompletionResults.Results createTestInstance() { // builds a Results holding 1 to 3 random chunks
+     var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
+     for (int i = 0, size = randomIntBetween(1, 3); i < size; i++) { // bound drawn once, not re-rolled on every iteration
+         results.offer(randomChatCompletionChunk());
+     }
+     return new StreamingUnifiedChatCompletionResults.Results(results);
+ }
+
+ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() {
+ Supplier randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5);
+ return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
+ randomAlphanumericOfLength(5),
+ randomBoolean() ? null : randomList(randomInt(5), () -> {
+ return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
+ new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
+ randomOptionalString.get(),
+ randomOptionalString.get(),
+ randomOptionalString.get(),
+ randomBoolean() ? null : randomList(randomInt(5), () -> {
+ return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
+ randomInt(5),
+ randomOptionalString.get(),
+ randomBoolean()
+ ? null
+ : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
+ randomOptionalString.get(),
+ randomOptionalString.get()
+ ),
+ randomOptionalString.get()
+ );
+ })
+ ),
+ randomOptionalString.get(),
+ randomInt(5)
+ );
+ }),
+ randomAlphanumericOfLength(5),
+ randomAlphanumericOfLength(5),
+ randomBoolean()
+ ? null
+ : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5))
+ );
+ }
+
+ @Override
+ protected StreamingUnifiedChatCompletionResults.Results mutateInstance(StreamingUnifiedChatCompletionResults.Results instance)
+ throws IOException {
+ var results = new ArrayDeque<>(instance.chunks());
+ if (randomBoolean()) {
+ results.pop();
+ } else {
+ results.add(randomChatCompletionChunk());
+ }
+ return new StreamingUnifiedChatCompletionResults.Results(results); // immutable
+ }
}
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
index b0e43c8607078..9bbed11ac506d 100644
--- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
+++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
@@ -11,11 +11,9 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.LazyInitializable;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
@@ -31,6 +29,7 @@
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
@@ -39,6 +38,7 @@
import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -158,16 +158,31 @@ public void cancel() {}
});
}
- private ChunkedToXContent completionChunk(String delta) {
- return params -> Iterators.concat(
- ChunkedToXContentHelper.startObject(),
- ChunkedToXContentHelper.startArray(COMPLETION),
- ChunkedToXContentHelper.startObject(),
- ChunkedToXContentHelper.field("delta", delta),
- ChunkedToXContentHelper.endObject(),
- ChunkedToXContentHelper.endArray(),
- ChunkedToXContentHelper.endObject()
- );
+ /**
+ * Wraps a single completion delta in a streaming {@link InferenceServiceResults.Result}.
+ * The XContent form is emitted as one chunk: an object holding an array (named by
+ * {@code COMPLETION}) with a single object whose {@code delta} field is the given text.
+ * The wire form written by {@code writeTo} is just the delta string.
+ *
+ * @param delta the text fragment carried by this chunk
+ * @return a result registered under the writeable name {@code test_completionChunk}
+ */
+ private InferenceServiceResults.Result completionChunk(String delta) {
+ return new InferenceServiceResults.Result() {
+ @Override
+ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+ // The whole payload is small, so it is rendered as a single chunk.
+ return ChunkedToXContentHelper.singleChunk(
+ (b, p) -> b.startObject()
+ .startArray(COMPLETION)
+ .startObject()
+ .field("delta", delta)
+ .endObject()
+ .endArray()
+ .endObject()
+ );
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(delta);
+ }
+
+ @Override
+ public String getWriteableName() {
+ return "test_completionChunk";
+ }
+ };
}
private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) {
@@ -205,22 +220,37 @@ public void cancel() {}
"object": "chat.completion.chunk"
}
*/
- private ChunkedToXContent unifiedCompletionChunk(String delta) {
- return params -> Iterators.concat(
- ChunkedToXContentHelper.startObject(),
- ChunkedToXContentHelper.field("id", "id"),
- ChunkedToXContentHelper.startArray("choices"),
- ChunkedToXContentHelper.startObject(),
- ChunkedToXContentHelper.startObject("delta"),
- ChunkedToXContentHelper.field("content", delta),
- ChunkedToXContentHelper.endObject(),
- ChunkedToXContentHelper.field("index", 0),
- ChunkedToXContentHelper.endObject(),
- ChunkedToXContentHelper.endArray(),
- ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"),
- ChunkedToXContentHelper.field("object", "chat.completion.chunk"),
- ChunkedToXContentHelper.endObject()
- );
+ /**
+ * Wraps a single delta in a streaming {@link InferenceServiceResults.Result} shaped like a
+ * unified chat-completion chunk. The XContent form is emitted as one chunk with fixed
+ * {@code id}, {@code model} and {@code object} values and one entry in {@code choices}
+ * (index 0) whose {@code delta.content} is the given text. The wire form written by
+ * {@code writeTo} is just the delta string.
+ *
+ * @param delta the text placed in {@code choices[0].delta.content}
+ * @return a result registered under the writeable name {@code test_unifiedCompletionChunk}
+ */
+ private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
+ return new InferenceServiceResults.Result() {
+ @Override
+ public String getWriteableName() {
+ return "test_unifiedCompletionChunk";
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(delta);
+ }
+
+ @Override
+ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+ // The whole payload is small, so it is rendered as a single chunk.
+ return ChunkedToXContentHelper.singleChunk(
+ (b, p) -> b.startObject()
+ .field("id", "id")
+ .startArray("choices")
+ .startObject()
+ .startObject("delta")
+ .field("content", delta)
+ .endObject()
+ .field("index", 0)
+ .endObject()
+ .endArray()
+ .field("model", "gpt-4o-2024-08-06")
+ .field("object", "chat.completion.chunk")
+ .endObject()
+ );
+ }
+ };
}
@Override
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java
index af1acc7530dce..f837ff5c4049d 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java
@@ -34,6 +34,7 @@
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
@@ -178,14 +179,14 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
}
private static class StreamingInferenceServiceResults implements InferenceServiceResults {
- private final Flow.Publisher publisher;
+ private final Flow.Publisher publisher;
- private StreamingInferenceServiceResults(Flow.Publisher publisher) {
+ private StreamingInferenceServiceResults(Flow.Publisher publisher) {
this.publisher = publisher;
}
@Override
- public Flow.Publisher publisher() {
+ public Flow.Publisher publisher() {
return publisher;
}
@@ -223,7 +224,7 @@ public Iterator extends ToXContent> toXContentChunked(ToXContent.Params params
}
}
- private static class RandomPublisher implements Flow.Publisher {
+ private static class RandomPublisher implements Flow.Publisher {
private final int requestCount;
private final boolean withError;
@@ -233,7 +234,7 @@ private RandomPublisher(int requestCount, boolean withError) {
}
@Override
- public void subscribe(Flow.Subscriber super ChunkedToXContent> subscriber) {
+ public void subscribe(Flow.Subscriber super InferenceServiceResults.Result> subscriber) {
var resultCount = new AtomicInteger(requestCount);
subscriber.onSubscribe(new Flow.Subscription() {
@Override
@@ -255,11 +256,24 @@ public void cancel() {
}
}
- private static class RandomString implements ChunkedToXContent {
+ /**
+ * Test result carrying one string, rendered as {@code {"delta": randomString}} in XContent
+ * and written as a plain string on the wire. The no-arg constructor draws the string once
+ * (random unicode, length 2 to 20), so both serialized forms of an instance carry the same value.
+ *
+ * @param randomString the value emitted under the {@code delta} field
+ */
+ private record RandomString(String randomString) implements InferenceServiceResults.Result {
+ RandomString() {
+ this(randomUnicodeOfLengthBetween(2, 20));
+ }
+
@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
- var randomString = randomUnicodeOfLengthBetween(2, 20);
- return ChunkedToXContent.builder(params).object(b -> b.field("delta", randomString));
+ return ChunkedToXContentHelper.singleChunk((b, p) -> b.startObject().field("delta", randomString).endObject());
+ }
+
+ @Override
+ public String getWriteableName() {
+ return "test_RandomString";
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(randomString);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index e8dc763116707..563247de44f81 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -23,6 +23,8 @@
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
@@ -484,6 +486,20 @@ private static void addInferenceResultsNamedWriteables(List namedWriteables) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
index b390a51f6d3e2..3552fd8cacdf8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
@@ -19,7 +19,6 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
@@ -276,7 +275,10 @@ private void inferOnServiceWithMetrics(
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
if (request.isStreaming()) {
- var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+ var taskProcessor = streamingTaskManager.create(
+ STREAMING_INFERENCE_TASK_TYPE,
+ STREAMING_TASK_ACTION
+ );
inferenceResults.publisher().subscribe(taskProcessor);
var instrumentedStream = new PublisherWithMetrics(timer, model);
@@ -295,7 +297,7 @@ private void inferOnServiceWithMetrics(
}));
}
- protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) {
+ protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) {
return upstream;
}
@@ -349,7 +351,7 @@ private static ElasticsearchStatusException requestModelTaskTypeMismatchExceptio
);
}
- private class PublisherWithMetrics extends DelegatingProcessor {
+ private class PublisherWithMetrics extends DelegatingProcessor {
private final InferenceTimer timer;
private final Model model;
@@ -360,7 +362,7 @@ private PublisherWithMetrics(InferenceTimer timer, Model model) {
}
@Override
- protected void next(ChunkedToXContent item) {
+ protected void next(InferenceServiceResults.Result item) {
downstream().onNext(item);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
index 1144a11d86cc9..4c8f03fae9184 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
@@ -11,7 +11,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.internal.node.NodeClient;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
@@ -103,7 +102,7 @@ protected void doExecute(Task task, UnifiedCompletionAction.Request request, Act
* as {@link UnifiedChatCompletionException}.
*/
@Override
- protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) {
+ protected Flow.Publisher streamErrorHandler(Flow.Processor upstream) {
return downstream -> {
upstream.subscribe(new Flow.Subscriber<>() {
@Override
@@ -112,7 +111,7 @@ public void onSubscribe(Flow.Subscription subscription) {
}
@Override
- public void onNext(ChunkedToXContent item) {
+ public void onNext(T item) {
downstream.onNext(item);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java
index f1cfc84643b1c..4369be6bf6209 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java
@@ -15,7 +15,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.inference.InferenceServiceResults;
import java.time.Instant;
import java.util.concurrent.Flow;
@@ -23,7 +23,8 @@
public interface AmazonBedrockClient {
void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException;
- Flow.Publisher extends ChunkedToXContent> converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException;
+ Flow.Publisher extends InferenceServiceResults.Result> converseStream(ConverseStreamRequest converseStreamRequest)
+ throws ElasticsearchException;
void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener)
throws ElasticsearchException;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java
index be90fbbd214d0..c0ee8d4620661 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java
@@ -26,10 +26,10 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
import org.reactivestreams.FlowAdapters;
@@ -88,7 +88,8 @@ public void converse(ConverseRequest converseRequest, ActionListener converseStream(ConverseStreamRequest request) throws ElasticsearchException {
+ public Flow.Publisher extends InferenceServiceResults.Result> converseStream(ConverseStreamRequest request)
+ throws ElasticsearchException {
var awsResponseProcessor = new AmazonBedrockStreamingChatProcessor(threadPool);
internalClient.converseStream(
request,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java
index 48c8132035b50..fcfd8e19004c5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java
@@ -9,8 +9,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -101,7 +101,7 @@
*
*
*/
-public class OpenAiStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> {
+public class OpenAiStreamingProcessor extends DelegatingProcessor, InferenceServiceResults.Result> {
private static final Logger log = LogManager.getLogger(OpenAiStreamingProcessor.class);
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response";
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java
index bfd4456279a8a..59c17e890e9f5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java
@@ -9,7 +9,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
@@ -34,7 +33,9 @@
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
-public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> {
+public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
+ Deque,
+ StreamingUnifiedChatCompletionResults.Results> {
public static final String FUNCTION_FIELD = "function";
private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java
index 05d7d90873a71..6c4788fd5142d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java
@@ -10,9 +10,9 @@
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.common.socket.SocketAccess;
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient;
@@ -82,7 +82,9 @@ public void executeChatCompletionRequest(
this.executeRequest(awsBedrockClient);
}
- public Flow.Publisher extends ChunkedToXContent> executeStreamChatCompletionRequest(AmazonBedrockBaseClient awsBedrockClient) {
+ public Flow.Publisher extends InferenceServiceResults.Result> executeStreamChatCompletionRequest(
+ AmazonBedrockBaseClient awsBedrockClient
+ ) {
var converseStreamRequest = ConverseStreamRequest.builder()
.modelId(amazonBedrockModel.model())
.messages(getConverseMessageList(requestEntity.messages()));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
index 56966ca40c478..4562b149f3b37 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
@@ -308,7 +308,7 @@ public void onSubscribe(Flow.Subscription subscription) {
}
@Override
- public void onNext(ChunkedToXContent item) {
+ public void onNext(InferenceServiceResults.Result item) {
}
@@ -332,7 +332,7 @@ public void onComplete() {
}));
}
- protected Flow.Publisher mockStreamResponse(Consumer> action) {
+ protected Flow.Publisher mockStreamResponse(Consumer> action) {
mockService(true, Set.of(), listener -> {
Flow.Processor taskProcessor = mock();
doAnswer(innerAns -> {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java
index 0f127998f9c54..3fc853bac3bbd 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java
@@ -65,31 +65,31 @@ public void testJsonLiteral() {
.parse(parser);
// Assertions to verify the parsed object
- assertEquals("example_id", chunk.getId());
- assertEquals("example_model", chunk.getModel());
- assertEquals("chat.completion.chunk", chunk.getObject());
- assertNotNull(chunk.getUsage());
- assertEquals(50, chunk.getUsage().completionTokens());
- assertEquals(20, chunk.getUsage().promptTokens());
- assertEquals(70, chunk.getUsage().totalTokens());
+ assertEquals("example_id", chunk.id());
+ assertEquals("example_model", chunk.model());
+ assertEquals("chat.completion.chunk", chunk.object());
+ assertNotNull(chunk.usage());
+ assertEquals(50, chunk.usage().completionTokens());
+ assertEquals(20, chunk.usage().promptTokens());
+ assertEquals(70, chunk.usage().totalTokens());
- List choices = chunk.getChoices();
+ List choices = chunk.choices();
assertEquals(1, choices.size());
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0);
- assertEquals("example_content", choice.delta().getContent());
- assertNull(choice.delta().getRefusal());
- assertEquals("assistant", choice.delta().getRole());
+ assertEquals("example_content", choice.delta().content());
+ assertNull(choice.delta().refusal());
+ assertEquals("assistant", choice.delta().role());
assertEquals("stop", choice.finishReason());
assertEquals(0, choice.index());
- List toolCalls = choice.delta().getToolCalls();
+ List toolCalls = choice.delta().toolCalls();
assertEquals(1, toolCalls.size());
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
- assertEquals(1, toolCall.getIndex());
- assertEquals("tool_call_id", toolCall.getId());
- assertEquals("example_function_name", toolCall.getFunction().getName());
- assertEquals("example_arguments", toolCall.getFunction().getArguments());
- assertEquals("function", toolCall.getType());
+ assertEquals(1, toolCall.index());
+ assertEquals("tool_call_id", toolCall.id());
+ assertEquals("example_function_name", toolCall.function().name());
+ assertEquals("example_arguments", toolCall.function().arguments());
+ assertEquals("function", toolCall.type());
} catch (IOException e) {
fail();
}
@@ -143,40 +143,40 @@ public void testJsonLiteralCornerCases() {
.parse(parser);
// Assertions to verify the parsed object
- assertEquals("example_id", chunk.getId());
- assertEquals("example_model", chunk.getModel());
- assertEquals("chat.completion.chunk", chunk.getObject());
- assertNull(chunk.getUsage());
+ assertEquals("example_id", chunk.id());
+ assertEquals("example_model", chunk.model());
+ assertEquals("chat.completion.chunk", chunk.object());
+ assertNull(chunk.usage());
- List choices = chunk.getChoices();
+ List choices = chunk.choices();
assertEquals(2, choices.size());
// First choice assertions
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0);
- assertNull(firstChoice.delta().getContent());
- assertNull(firstChoice.delta().getRefusal());
- assertEquals("assistant", firstChoice.delta().getRole());
- assertTrue(firstChoice.delta().getToolCalls().isEmpty());
+ assertNull(firstChoice.delta().content());
+ assertNull(firstChoice.delta().refusal());
+ assertEquals("assistant", firstChoice.delta().role());
+ assertTrue(firstChoice.delta().toolCalls().isEmpty());
assertNull(firstChoice.finishReason());
assertEquals(0, firstChoice.index());
// Second choice assertions
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1);
- assertEquals("example_content", secondChoice.delta().getContent());
- assertEquals("example_refusal", secondChoice.delta().getRefusal());
- assertEquals("user", secondChoice.delta().getRole());
+ assertEquals("example_content", secondChoice.delta().content());
+ assertEquals("example_refusal", secondChoice.delta().refusal());
+ assertEquals("user", secondChoice.delta().role());
assertEquals("stop", secondChoice.finishReason());
assertEquals(1, secondChoice.index());
List toolCalls = secondChoice.delta()
- .getToolCalls();
+ .toolCalls();
assertEquals(1, toolCalls.size());
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
- assertEquals(1, toolCall.getIndex());
- assertNull(toolCall.getId());
- assertEquals("example_function_name", toolCall.getFunction().getName());
- assertNull(toolCall.getFunction().getArguments());
- assertEquals("function", toolCall.getType());
+ assertEquals(1, toolCall.index());
+ assertNull(toolCall.id());
+ assertEquals("example_function_name", toolCall.function().name());
+ assertNull(toolCall.function().arguments());
+ assertEquals("function", toolCall.type());
} catch (IOException e) {
fail();
}
@@ -221,31 +221,31 @@ public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException {
.parse(parser);
// Assertions to verify the parsed object
- assertEquals(chatCompletionChunkId, chunk.getId());
- assertEquals(chatCompletionChunkModel, chunk.getModel());
- assertEquals("chat.completion.chunk", chunk.getObject());
- assertNotNull(chunk.getUsage());
- assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens());
- assertEquals(usagePromptTokens, chunk.getUsage().promptTokens());
- assertEquals(usageTotalTokens, chunk.getUsage().totalTokens());
+ assertEquals(chatCompletionChunkId, chunk.id());
+ assertEquals(chatCompletionChunkModel, chunk.model());
+ assertEquals("chat.completion.chunk", chunk.object());
+ assertNotNull(chunk.usage());
+ assertEquals(usageCompletionTokens, chunk.usage().completionTokens());
+ assertEquals(usagePromptTokens, chunk.usage().promptTokens());
+ assertEquals(usageTotalTokens, chunk.usage().totalTokens());
- List choices = chunk.getChoices();
+ List choices = chunk.choices();
assertEquals(1, choices.size());
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0);
- assertEquals(choiceContent, choice.delta().getContent());
- assertNull(choice.delta().getRefusal());
- assertEquals(choiceRole, choice.delta().getRole());
+ assertEquals(choiceContent, choice.delta().content());
+ assertNull(choice.delta().refusal());
+ assertEquals(choiceRole, choice.delta().role());
assertEquals(choiceFinishReason, choice.finishReason());
assertEquals(choiceIndex, choice.index());
- List toolCalls = choice.delta().getToolCalls();
+ List toolCalls = choice.delta().toolCalls();
assertEquals(1, toolCalls.size());
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
- assertEquals(toolCallIndex, toolCall.getIndex());
- assertEquals(toolCallId, toolCall.getId());
- assertEquals(toolCallFunctionName, toolCall.getFunction().getName());
- assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments());
- assertEquals(toolCallType, toolCall.getType());
+ assertEquals(toolCallIndex, toolCall.index());
+ assertEquals(toolCallId, toolCall.id());
+ assertEquals(toolCallFunctionName, toolCall.function().name());
+ assertEquals(toolCallFunctionArguments, toolCall.function().arguments());
+ assertEquals(toolCallType, toolCall.type());
}
}
@@ -273,20 +273,20 @@ public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IO
.parse(parser);
// Assertions to verify the parsed object
- assertEquals(chatCompletionChunkId, chunk.getId());
- assertEquals(chatCompletionChunkModel, chunk.getModel());
- assertEquals("chat.completion.chunk", chunk.getObject());
- assertNull(chunk.getUsage());
+ assertEquals(chatCompletionChunkId, chunk.id());
+ assertEquals(chatCompletionChunkModel, chunk.model());
+ assertEquals("chat.completion.chunk", chunk.object());
+ assertNull(chunk.usage());
- List choices = chunk.getChoices();
+ List choices = chunk.choices();
assertEquals(1, choices.size());
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0);
- assertNull(choice.delta().getContent());
- assertNull(choice.delta().getRefusal());
- assertNull(choice.delta().getRole());
+ assertNull(choice.delta().content());
+ assertNull(choice.delta().refusal());
+ assertNull(choice.delta().role());
assertNull(choice.finishReason());
assertEquals(choiceIndex, choice.index());
- assertTrue(choice.delta().getToolCalls().isEmpty());
+ assertTrue(choice.delta().toolCalls().isEmpty());
}
}