Skip to content

Commit beb5e38

Browse files
committed
[ML] Make Streaming Results Writeable
Make streaming elements extend Writeable and create StreamInput constructors so we can publish elements across nodes using the transport layer. Additional notes: - Moved optional methods into the InferenceServiceResults interface and default them - StreamingUnifiedChatCompletionResults elements are now all records
1 parent fda7fc7 commit beb5e38

File tree

18 files changed

+529
-271
lines changed

18 files changed

+529
-271
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.common.io.stream.NamedWriteable;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
1314
import org.elasticsearch.common.xcontent.ChunkedToXContent;
15+
import org.elasticsearch.xcontent.ToXContent;
1416

17+
import java.io.IOException;
18+
import java.util.Iterator;
1519
import java.util.List;
1620
import java.util.Map;
1721
import java.util.concurrent.Flow;
@@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
2731
*
2832
* <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.</p>
2933
*/
30-
List<? extends InferenceResults> transformToCoordinationFormat();
34+
default List<? extends InferenceResults> transformToCoordinationFormat() {
35+
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
36+
}
3137

3238
/**
3339
* Transform the result to match the format required for versions prior to
3440
* {@link org.elasticsearch.TransportVersions#V_8_12_0}
3541
*/
36-
List<? extends InferenceResults> transformToLegacyFormat();
42+
default List<? extends InferenceResults> transformToLegacyFormat() {
43+
throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
44+
}
3745

3846
/**
3947
* Convert the result to a map to aid with test assertions
4048
*/
41-
Map<String, Object> asMap();
49+
default Map<String, Object> asMap() {
50+
throw new UnsupportedOperationException("asMap() is not implemented");
51+
}
52+
53+
default String getWriteableName() {
54+
assert isStreaming() : "This must be implemented when isStreaming() == false";
55+
throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
56+
}
57+
58+
default void writeTo(StreamOutput out) throws IOException {
59+
assert isStreaming() : "This must be implemented when isStreaming() == false";
60+
throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
61+
}
62+
63+
default Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
64+
assert isStreaming() : "This must be implemented when isStreaming() == false";
65+
throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
66+
}
4267

4368
/**
4469
* Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
@@ -52,8 +77,10 @@ default boolean isStreaming() {
5277
* When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
5378
* Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
5479
*/
55-
default Flow.Publisher<? extends ChunkedToXContent> publisher() {
80+
default Flow.Publisher<? extends Result> publisher() {
5681
assert isStreaming() == false : "This must be implemented when isStreaming() == true";
5782
throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
5883
}
84+
85+
interface Result extends NamedWriteable, ChunkedToXContent {}
5986
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.elasticsearch.common.collect.Iterators;
1717
import org.elasticsearch.common.io.stream.StreamInput;
1818
import org.elasticsearch.common.io.stream.StreamOutput;
19-
import org.elasticsearch.common.xcontent.ChunkedToXContent;
2019
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
2120
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
2221
import org.elasticsearch.core.TimeValue;
@@ -342,15 +341,15 @@ public static class Response extends ActionResponse implements ChunkedToXContent
342341

343342
private final InferenceServiceResults results;
344343
private final boolean isStreaming;
345-
private final Flow.Publisher<ChunkedToXContent> publisher;
344+
private final Flow.Publisher<InferenceServiceResults.Result> publisher;
346345

347346
public Response(InferenceServiceResults results) {
348347
this.results = results;
349348
this.isStreaming = false;
350349
this.publisher = null;
351350
}
352351

353-
public Response(InferenceServiceResults results, Flow.Publisher<ChunkedToXContent> publisher) {
352+
public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
354353
this.results = results;
355354
this.isStreaming = true;
356355
this.publisher = publisher;
@@ -434,7 +433,7 @@ public boolean isStreaming() {
434433
* When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
435434
* If the RestResponse is closed, it will cancel the subscription.
436435
*/
437-
public Flow.Publisher<ChunkedToXContent> publisher() {
436+
public Flow.Publisher<InferenceServiceResults.Result> publisher() {
438437
assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
439438
return publisher;
440439
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,46 @@
88
package org.elasticsearch.xpack.core.inference.results;
99

1010
import org.elasticsearch.common.collect.Iterators;
11+
import org.elasticsearch.common.io.stream.StreamInput;
1112
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.common.io.stream.Writeable;
1214
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1315
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
14-
import org.elasticsearch.inference.InferenceResults;
1516
import org.elasticsearch.inference.InferenceServiceResults;
1617
import org.elasticsearch.xcontent.ToXContent;
1718

1819
import java.io.IOException;
20+
import java.util.ArrayDeque;
1921
import java.util.Deque;
2022
import java.util.Iterator;
21-
import java.util.List;
22-
import java.util.Map;
23+
import java.util.Objects;
2324
import java.util.concurrent.Flow;
2425

2526
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
2627

2728
/**
2829
* Chat Completion results that only contain a Flow.Publisher.
2930
*/
30-
public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults {
31+
public record StreamingChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
32+
implements
33+
InferenceServiceResults {
3134

3235
@Override
3336
public boolean isStreaming() {
3437
return true;
3538
}
3639

37-
@Override
38-
public List<? extends InferenceResults> transformToCoordinationFormat() {
39-
throw new UnsupportedOperationException("Not implemented");
40-
}
41-
42-
@Override
43-
public List<? extends InferenceResults> transformToLegacyFormat() {
44-
throw new UnsupportedOperationException("Not implemented");
45-
}
40+
public record Results(Deque<Result> results) implements InferenceServiceResults.Result {
41+
public static final String NAME = "streaming_chat_completion_results";
4642

47-
@Override
48-
public Map<String, Object> asMap() {
49-
throw new UnsupportedOperationException("Not implemented");
50-
}
51-
52-
@Override
53-
public String getWriteableName() {
54-
throw new UnsupportedOperationException("Not implemented");
55-
}
56-
57-
@Override
58-
public void writeTo(StreamOutput out) throws IOException {
59-
throw new UnsupportedOperationException("Not implemented");
60-
}
43+
public Results(StreamInput in) throws IOException {
44+
this(deque(in));
45+
}
6146

62-
@Override
63-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
64-
throw new UnsupportedOperationException("Not implemented");
65-
}
47+
private static Deque<Result> deque(StreamInput in) throws IOException {
48+
return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(new Result(stream))));
49+
}
6650

67-
public record Results(Deque<Result> results) implements ChunkedToXContent {
6851
@Override
6952
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
7053
return Iterators.concat(
@@ -75,14 +58,66 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
7558
ChunkedToXContentHelper.endObject()
7659
);
7760
}
61+
62+
@Override
63+
public String getWriteableName() {
64+
return NAME;
65+
}
66+
67+
@Override
68+
public void writeTo(StreamOutput out) throws IOException {
69+
out.writeCollection(results);
70+
}
71+
72+
@Override
73+
public boolean equals(Object o) {
74+
if (o == null || getClass() != o.getClass()) return false;
75+
Results other = (Results) o;
76+
return dequeEquals(this.results, other.results());
77+
}
78+
79+
private static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
80+
if (thisDeque.size() != otherDeque.size()) {
81+
return false;
82+
}
83+
var thisIter = thisDeque.iterator();
84+
var otherIter = otherDeque.iterator();
85+
while (thisIter.hasNext() && otherIter.hasNext()) {
86+
if (thisIter.next().equals(otherIter.next()) == false) {
87+
return false;
88+
}
89+
}
90+
return true;
91+
}
92+
93+
@Override
94+
public int hashCode() {
95+
return dequeHashCode(results);
96+
}
97+
98+
private static int dequeHashCode(Deque<?> deque) {
99+
if (deque == null) {
100+
return 0;
101+
}
102+
return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
103+
}
78104
}
79105

80-
public record Result(String delta) implements ChunkedToXContent {
106+
public record Result(String delta) implements ChunkedToXContent, Writeable {
81107
private static final String RESULT = "delta";
82108

109+
private Result(StreamInput in) throws IOException {
110+
this(in.readString());
111+
}
112+
83113
@Override
84114
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
85115
return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field(RESULT, delta).endObject());
86116
}
117+
118+
@Override
119+
public void writeTo(StreamOutput out) throws IOException {
120+
out.writeString(delta);
121+
}
87122
}
88123
}

0 commit comments

Comments
 (0)