-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[ML] Make Streaming Results Writeable #122527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
beb5e38
66ed99d
105e35e
cb028d3
2eab022
5c6d7a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,63 +8,45 @@ | |
| package org.elasticsearch.xpack.core.inference.results; | ||
|
|
||
| import org.elasticsearch.common.collect.Iterators; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.common.io.stream.Writeable; | ||
| import org.elasticsearch.common.xcontent.ChunkedToXContent; | ||
| import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; | ||
| import org.elasticsearch.inference.InferenceResults; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.xcontent.ToXContent; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.ArrayDeque; | ||
| import java.util.Deque; | ||
| import java.util.Iterator; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.concurrent.Flow; | ||
|
|
||
| import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; | ||
|
|
||
| /** | ||
| * Chat Completion results that only contain a Flow.Publisher. | ||
| */ | ||
| public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults { | ||
| public record StreamingChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher) | ||
| implements | ||
| InferenceServiceResults { | ||
|
|
||
| @Override | ||
| public boolean isStreaming() { | ||
| return true; | ||
| } | ||
|
|
||
| @Override | ||
| public List<? extends InferenceResults> transformToCoordinationFormat() { | ||
| throw new UnsupportedOperationException("Not implemented"); | ||
| } | ||
|
|
||
| @Override | ||
| public List<? extends InferenceResults> transformToLegacyFormat() { | ||
| throw new UnsupportedOperationException("Not implemented"); | ||
| } | ||
| public record Results(Deque<Result> results) implements InferenceServiceResults.Result { | ||
| public static final String NAME = "streaming_chat_completion_results"; | ||
|
|
||
| @Override | ||
| public Map<String, Object> asMap() { | ||
| throw new UnsupportedOperationException("Not implemented"); | ||
| } | ||
|
|
||
| @Override | ||
| public String getWriteableName() { | ||
| throw new UnsupportedOperationException("Not implemented"); | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| throw new UnsupportedOperationException("Not implemented"); | ||
| } | ||
| public Results(StreamInput in) throws IOException { | ||
| this(deque(in)); | ||
| } | ||
|
|
||
| @Override | ||
| public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) { | ||
| throw new UnsupportedOperationException("Not implemented"); | ||
| } | ||
| private static Deque<Result> deque(StreamInput in) throws IOException { | ||
| return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(new Result(stream)))); | ||
| } | ||
|
|
||
| public record Results(Deque<Result> results) implements ChunkedToXContent { | ||
| @Override | ||
| public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) { | ||
| return Iterators.concat( | ||
|
|
@@ -75,14 +57,66 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params | |
| ChunkedToXContentHelper.endObject() | ||
| ); | ||
| } | ||
|
|
||
| @Override | ||
| public String getWriteableName() { | ||
| return NAME; | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| out.writeCollection(results); | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| Results other = (Results) o; | ||
| return dequeEquals(this.results, other.results()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah was this why you were saying deque was a bad idea haha?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it took me a few minutes to realize most of the Deque implementations do not have an equals method =( |
||
| } | ||
|
|
||
| private static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) { | ||
| if (thisDeque.size() != otherDeque.size()) { | ||
| return false; | ||
| } | ||
| var thisIter = thisDeque.iterator(); | ||
| var otherIter = otherDeque.iterator(); | ||
| while (thisIter.hasNext() && otherIter.hasNext()) { | ||
| if (thisIter.next().equals(otherIter.next()) == false) { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return dequeHashCode(results); | ||
| } | ||
|
|
||
| private static int dequeHashCode(Deque<?> deque) { | ||
| if (deque == null) { | ||
| return 0; | ||
| } | ||
| return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum); | ||
| } | ||
| } | ||
|
|
||
| public record Result(String delta) implements ChunkedToXContent { | ||
| public record Result(String delta) implements ChunkedToXContent, Writeable { | ||
| private static final String RESULT = "delta"; | ||
|
|
||
| private Result(StreamInput in) throws IOException { | ||
| this(in.readString()); | ||
| } | ||
|
|
||
| @Override | ||
| public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) { | ||
| return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field(RESULT, delta).endObject()); | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| out.writeString(delta); | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️