Skip to content

Commit ecaf94d

Browse files
prwhelan and elasticsearchmachine authored
[ML] Make Streaming Results Writeable (#122527) (#123139)
* [ML] Make Streaming Results Writeable (#122527)

  Make streaming elements extend Writeable and create StreamInput constructors so we can publish elements across nodes using the transport layer.

  Additional notes:
  - Moved optional methods into the InferenceServiceResults interface and gave them default implementations
  - StreamingUnifiedChatCompletionResults elements are now all records

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent d643e3c commit ecaf94d

File tree

20 files changed

+596
-283
lines changed

20 files changed

+596
-283
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.common.io.stream.NamedWriteable;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
1314
import org.elasticsearch.common.xcontent.ChunkedToXContent;
15+
import org.elasticsearch.xcontent.ToXContent;
1416

17+
import java.io.IOException;
18+
import java.util.Iterator;
1519
import java.util.List;
1620
import java.util.Map;
1721
import java.util.concurrent.Flow;
@@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
2731
*
2832
* <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.</p>
2933
*/
30-
List<? extends InferenceResults> transformToCoordinationFormat();
34+
default List<? extends InferenceResults> transformToCoordinationFormat() {
35+
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
36+
}
3137

3238
/**
3339
* Transform the result to match the format required for versions prior to
3440
* {@link org.elasticsearch.TransportVersions#V_8_12_0}
3541
*/
36-
List<? extends InferenceResults> transformToLegacyFormat();
42+
default List<? extends InferenceResults> transformToLegacyFormat() {
43+
throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
44+
}
3745

3846
/**
3947
* Convert the result to a map to aid with test assertions
4048
*/
41-
Map<String, Object> asMap();
49+
default Map<String, Object> asMap() {
50+
throw new UnsupportedOperationException("asMap() is not implemented");
51+
}
52+
53+
default String getWriteableName() {
54+
assert isStreaming() : "This must be implemented when isStreaming() == false";
55+
throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
56+
}
57+
58+
default void writeTo(StreamOutput out) throws IOException {
59+
assert isStreaming() : "This must be implemented when isStreaming() == false";
60+
throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
61+
}
62+
63+
default Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
64+
assert isStreaming() : "This must be implemented when isStreaming() == false";
65+
throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
66+
}
4267

4368
/**
4469
* Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
@@ -52,8 +77,10 @@ default boolean isStreaming() {
5277
* When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
5378
* Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
5479
*/
55-
default Flow.Publisher<? extends ChunkedToXContent> publisher() {
80+
default Flow.Publisher<? extends Result> publisher() {
5681
assert isStreaming() == false : "This must be implemented when isStreaming() == true";
5782
throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
5883
}
84+
85+
interface Result extends NamedWriteable, ChunkedToXContent {}
5986
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.Writeable;
12+
13+
import java.io.IOException;
14+
import java.util.ArrayDeque;
15+
import java.util.Deque;
16+
17+
public final class DequeUtils {
18+
19+
private DequeUtils() {
20+
// util functions only
21+
}
22+
23+
public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
24+
return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in))));
25+
}
26+
27+
public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
28+
if (thisDeque.size() != otherDeque.size()) {
29+
return false;
30+
}
31+
var thisIter = thisDeque.iterator();
32+
var otherIter = otherDeque.iterator();
33+
while (thisIter.hasNext() && otherIter.hasNext()) {
34+
if (thisIter.next().equals(otherIter.next()) == false) {
35+
return false;
36+
}
37+
}
38+
return true;
39+
}
40+
41+
public static int dequeHashCode(Deque<?> deque) {
42+
if (deque == null) {
43+
return 0;
44+
}
45+
return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
46+
}
47+
48+
public static <T> Deque<T> of(T elem) {
49+
var deque = new ArrayDeque<T>(1);
50+
deque.offer(elem);
51+
return deque;
52+
}
53+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.elasticsearch.common.collect.Iterators;
1717
import org.elasticsearch.common.io.stream.StreamInput;
1818
import org.elasticsearch.common.io.stream.StreamOutput;
19-
import org.elasticsearch.common.xcontent.ChunkedToXContent;
2019
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
2120
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
2221
import org.elasticsearch.core.TimeValue;
@@ -342,15 +341,15 @@ public static class Response extends ActionResponse implements ChunkedToXContent
342341

343342
private final InferenceServiceResults results;
344343
private final boolean isStreaming;
345-
private final Flow.Publisher<ChunkedToXContent> publisher;
344+
private final Flow.Publisher<InferenceServiceResults.Result> publisher;
346345

347346
public Response(InferenceServiceResults results) {
348347
this.results = results;
349348
this.isStreaming = false;
350349
this.publisher = null;
351350
}
352351

353-
public Response(InferenceServiceResults results, Flow.Publisher<ChunkedToXContent> publisher) {
352+
public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
354353
this.results = results;
355354
this.isStreaming = true;
356355
this.publisher = publisher;
@@ -434,7 +433,7 @@ public boolean isStreaming() {
434433
* When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
435434
* If the RestResponse is closed, it will cancel the subscription.
436435
*/
437-
public Flow.Publisher<ChunkedToXContent> publisher() {
436+
public Flow.Publisher<InferenceServiceResults.Result> publisher() {
438437
assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
439438
return publisher;
440439
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,43 @@
88
package org.elasticsearch.xpack.core.inference.results;
99

1010
import org.elasticsearch.common.collect.Iterators;
11+
import org.elasticsearch.common.io.stream.StreamInput;
1112
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.common.io.stream.Writeable;
1214
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1315
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
14-
import org.elasticsearch.inference.InferenceResults;
1516
import org.elasticsearch.inference.InferenceServiceResults;
1617
import org.elasticsearch.xcontent.ToXContent;
1718

1819
import java.io.IOException;
1920
import java.util.Deque;
2021
import java.util.Iterator;
21-
import java.util.List;
22-
import java.util.Map;
2322
import java.util.concurrent.Flow;
2423

24+
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
25+
import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
26+
import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque;
2527
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
2628

2729
/**
2830
* Chat Completion results that only contain a Flow.Publisher.
2931
*/
30-
public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults {
32+
public record StreamingChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
33+
implements
34+
InferenceServiceResults {
3135

3236
@Override
3337
public boolean isStreaming() {
3438
return true;
3539
}
3640

37-
@Override
38-
public List<? extends InferenceResults> transformToCoordinationFormat() {
39-
throw new UnsupportedOperationException("Not implemented");
40-
}
41-
42-
@Override
43-
public List<? extends InferenceResults> transformToLegacyFormat() {
44-
throw new UnsupportedOperationException("Not implemented");
45-
}
41+
public record Results(Deque<Result> results) implements InferenceServiceResults.Result {
42+
public static final String NAME = "streaming_chat_completion_results";
4643

47-
@Override
48-
public Map<String, Object> asMap() {
49-
throw new UnsupportedOperationException("Not implemented");
50-
}
51-
52-
@Override
53-
public String getWriteableName() {
54-
throw new UnsupportedOperationException("Not implemented");
55-
}
56-
57-
@Override
58-
public void writeTo(StreamOutput out) throws IOException {
59-
throw new UnsupportedOperationException("Not implemented");
60-
}
61-
62-
@Override
63-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
64-
throw new UnsupportedOperationException("Not implemented");
65-
}
44+
public Results(StreamInput in) throws IOException {
45+
this(readDeque(in, Result::new));
46+
}
6647

67-
public record Results(Deque<Result> results) implements ChunkedToXContent {
6848
@Override
6949
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
7050
return Iterators.concat(
@@ -75,11 +55,37 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
7555
ChunkedToXContentHelper.endObject()
7656
);
7757
}
58+
59+
@Override
60+
public String getWriteableName() {
61+
return NAME;
62+
}
63+
64+
@Override
65+
public void writeTo(StreamOutput out) throws IOException {
66+
out.writeCollection(results);
67+
}
68+
69+
@Override
70+
public boolean equals(Object o) {
71+
if (o == null || getClass() != o.getClass()) return false;
72+
Results other = (Results) o;
73+
return dequeEquals(this.results, other.results());
74+
}
75+
76+
@Override
77+
public int hashCode() {
78+
return dequeHashCode(results);
79+
}
7880
}
7981

80-
public record Result(String delta) implements ChunkedToXContent {
82+
public record Result(String delta) implements ChunkedToXContent, Writeable {
8183
private static final String RESULT = "delta";
8284

85+
private Result(StreamInput in) throws IOException {
86+
this(in.readString());
87+
}
88+
8389
@Override
8490
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
8591
return Iterators.concat(
@@ -88,5 +94,10 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
8894
ChunkedToXContentHelper.endObject()
8995
);
9096
}
97+
98+
@Override
99+
public void writeTo(StreamOutput out) throws IOException {
100+
out.writeString(delta);
101+
}
91102
}
92103
}

0 commit comments

Comments
 (0)