Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.xcontent.ToXContent;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Flow;
Expand All @@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
*
* <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.</p>
*/
List<? extends InferenceResults> transformToCoordinationFormat();
default List<? extends InferenceResults> transformToCoordinationFormat() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
}

/**
* Transform the result to match the format required for versions prior to
* {@link org.elasticsearch.TransportVersions#V_8_12_0}
*/
List<? extends InferenceResults> transformToLegacyFormat();
default List<? extends InferenceResults> transformToLegacyFormat() {
    // Default rejects the call; per the Javadoc above, only the pre-8.12 transport
    // format path needs this, so new result types may leave it unimplemented.
    throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
}

/**
* Convert the result to a map to aid with test assertions
*/
Map<String, Object> asMap();
default Map<String, Object> asMap() {
    // Test-assertion helper (see Javadoc above); the default simply rejects the call
    // so implementations that are never asserted on need not provide it.
    throw new UnsupportedOperationException("asMap() is not implemented");
}

default String getWriteableName() {
    // Must be overridden by non-streaming implementations. The assert (enabled in
    // test builds) trips when a non-streaming result reaches this default; the
    // unconditional throw also guards assert-disabled production builds.
    assert isStreaming() : "This must be implemented when isStreaming() == false";
    throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
}

default void writeTo(StreamOutput out) throws IOException {
    // Must be overridden by non-streaming implementations. The assert (enabled in
    // test builds) trips when a non-streaming result reaches this default; the
    // unconditional throw also guards assert-disabled production builds.
    assert isStreaming() : "This must be implemented when isStreaming() == false";
    throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
}

default Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
    // Must be overridden by non-streaming implementations. The assert (enabled in
    // test builds) trips when a non-streaming result reaches this default; the
    // unconditional throw also guards assert-disabled production builds.
    assert isStreaming() : "This must be implemented when isStreaming() == false";
    throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
}

/**
* Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
Expand All @@ -52,8 +77,10 @@ default boolean isStreaming() {
* When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
* Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
*/
default Flow.Publisher<? extends ChunkedToXContent> publisher() {
default Flow.Publisher<? extends Result> publisher() {
assert isStreaming() == false : "This must be implemented when isStreaming() == true";
throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
}

/**
 * Item type emitted by {@link #publisher()}: wire-serializable ({@link NamedWriteable})
 * and renderable as chunked XContent.
 */
interface Result extends NamedWriteable, ChunkedToXContent {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.core.TimeValue;
Expand Down Expand Up @@ -342,15 +341,15 @@ public static class Response extends ActionResponse implements ChunkedToXContent

private final InferenceServiceResults results;
private final boolean isStreaming;
// Non-null only when isStreaming == true (set by the streaming constructor).
// NOTE(review): dropped the stale pre-change Flow.Publisher<ChunkedToXContent>
// declaration that the diff showed alongside this one.
private final Flow.Publisher<InferenceServiceResults.Result> publisher;

public Response(InferenceServiceResults results) {
    // Non-streaming response: the full payload lives in `results`; no publisher.
    this.results = results;
    this.isStreaming = false;
    this.publisher = null;
}

public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
    // Streaming response: chunks arrive through `publisher` rather than in `results`.
    // NOTE(review): the closing brace was swallowed by a diff "Expand" marker in the
    // pasted source; restored here.
    this.results = results;
    this.isStreaming = true;
    this.publisher = publisher;
}
* When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
* If the RestResponse is closed, it will cancel the subscription.
*/
public Flow.Publisher<ChunkedToXContent> publisher() {
public Flow.Publisher<InferenceServiceResults.Result> publisher() {
assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
return publisher;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,63 +8,45 @@
package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Flow;

import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;

/**
* Chat Completion results that only contain a Flow.Publisher.
*/
public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults {
public record StreamingChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
implements
InferenceServiceResults {

@Override
public boolean isStreaming() {
    // This result type only ever carries a Flow.Publisher of chunks, so it is
    // always streaming.
    return true;
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
throw new UnsupportedOperationException("Not implemented");
}
public record Results(Deque<Result> results) implements InferenceServiceResults.Result {
public static final String NAME = "streaming_chat_completion_results";

@Override
public Map<String, Object> asMap() {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public String getWriteableName() {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new UnsupportedOperationException("Not implemented");
}
public Results(StreamInput in) throws IOException {
    // Wire deserialization: rebuilds the chunk deque written by writeTo().
    this(deque(in));
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
throw new UnsupportedOperationException("Not implemented");
}
/**
 * Reads the serialized chunk collection back into a {@link Deque}, preserving
 * arrival order.
 */
private static Deque<Result> deque(StreamInput in) throws IOException {
    return in.readCollection(ArrayDeque::new, (stream, queue) -> queue.offer(new Result(stream)));
}

public record Results(Deque<Result> results) implements ChunkedToXContent {
@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
return Iterators.concat(
Expand All @@ -75,14 +57,66 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
ChunkedToXContentHelper.endObject()
);
}

@Override
public String getWriteableName() {
    // Wire name under which this writeable is registered/looked up.
    return NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
    // Serializes each chunk in queue order; mirrored by the Results(StreamInput) constructor.
    out.writeCollection(results);
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Results other = (Results) o;
return dequeEquals(this.results, other.results());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah was this why you were saying deque was a bad idea haha?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it took me a few minutes to realize most of the Deque implementations do not have an equals method =(

}

/**
 * Order-sensitive, element-wise equality for two deques, which otherwise fall back
 * to Object identity equality.
 */
private static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
    if (thisDeque.size() != otherDeque.size()) {
        return false;
    }
    // Walk both deques in lockstep; sizes already match, so the loop condition is
    // just a defensive guard.
    for (Iterator<?> left = thisDeque.iterator(), right = otherDeque.iterator(); left.hasNext() && right.hasNext();) {
        if (left.next().equals(right.next()) == false) {
            return false;
        }
    }
    return true;
}

@Override
public int hashCode() {
    // Deque implementations inherit Object.hashCode, so hash the elements manually
    // to stay consistent with the element-wise equals above.
    return dequeHashCode(results);
}

/**
 * Order-sensitive hash of the deque's elements (same scheme as {@code List.hashCode()}),
 * since Deque implementations inherit identity-based hashCode from Object.
 */
private static int dequeHashCode(Deque<?> deque) {
    if (deque == null) {
        return 0;
    }
    // Plain loop instead of the previous three-arg Stream.reduce: that reduce violated
    // the reduce contract (1 is not an identity for the Integer::sum combiner and the
    // accumulator is not associative), so it only worked by accident on sequential
    // streams. Sequential behavior is unchanged here.
    int hashCode = 1;
    for (Object chunk : deque) {
        hashCode = 31 * hashCode + (chunk == null ? 0 : chunk.hashCode());
    }
    return hashCode;
}
}

public record Result(String delta) implements ChunkedToXContent {
public record Result(String delta) implements ChunkedToXContent, Writeable {
private static final String RESULT = "delta";

private Result(StreamInput in) throws IOException {
this(in.readString());
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field(RESULT, delta).endObject());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(delta);
}
}
}
Loading