InferenceServiceResults.java
@@ -10,8 +10,12 @@
 package org.elasticsearch.inference;
 
 import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.xcontent.ToXContent;
 
+import java.io.IOException;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Flow;
@@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXContent {
      *
      * <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.</p>
      */
-    List<? extends InferenceResults> transformToCoordinationFormat();
+    default List<? extends InferenceResults> transformToCoordinationFormat() {
+        throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
+    }
 
     /**
      * Transform the result to match the format required for versions prior to
      * {@link org.elasticsearch.TransportVersions#V_8_12_0}
      */
-    List<? extends InferenceResults> transformToLegacyFormat();
+    default List<? extends InferenceResults> transformToLegacyFormat() {
+        throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
+    }
 
     /**
      * Convert the result to a map to aid with test assertions
      */
-    Map<String, Object> asMap();
+    default Map<String, Object> asMap() {
+        throw new UnsupportedOperationException("asMap() is not implemented");
+    }
 
+    default String getWriteableName() {
+        assert isStreaming() : "This must be implemented when isStreaming() == false";
+        throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+    }
+
+    default void writeTo(StreamOutput out) throws IOException {
+        assert isStreaming() : "This must be implemented when isStreaming() == false";
+        throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+    }
+
+    default Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+        assert isStreaming() : "This must be implemented when isStreaming() == false";
+        throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+    }
+
     /**
     * Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
@@ -52,8 +77,10 @@ default boolean isStreaming() {
      * When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
      * Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
      */
-    default Flow.Publisher<? extends ChunkedToXContent> publisher() {
+    default Flow.Publisher<? extends Result> publisher() {
         assert isStreaming() == false : "This must be implemented when isStreaming() == true";
         throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
     }
+
+    interface Result extends NamedWriteable, ChunkedToXContent {}
 }
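
Note on the design: with these throwing defaults in place, a streaming result type only has to supply publisher() and isStreaming(); all of the non-streaming methods fall back to the defaults above, which the asserts tie to the value of isStreaming(). A minimal sketch of such an implementation (the record name ExampleStreamingResults is hypothetical, not part of this diff):

import java.util.concurrent.Flow;

import org.elasticsearch.inference.InferenceServiceResults;

// The record component generates a publisher() accessor that overrides the
// throwing default; isStreaming() is the only other method to implement.
public record ExampleStreamingResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
    implements
        InferenceServiceResults {

    @Override
    public boolean isStreaming() {
        return true;
    }

    // transformToCoordinationFormat(), transformToLegacyFormat(), asMap(),
    // getWriteableName(), writeTo() and toXContentChunked() all inherit the
    // throwing defaults, which streaming callers should never reach.
}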
DequeUtils.java (new file)
@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
+public final class DequeUtils {
+
+    private DequeUtils() {
+        // util functions only
+    }
+
+    public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
+        return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in))));
+    }
+
+    public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
+        if (thisDeque.size() != otherDeque.size()) {
+            return false;
+        }
+        var thisIter = thisDeque.iterator();
+        var otherIter = otherDeque.iterator();
+        while (thisIter.hasNext() && otherIter.hasNext()) {
+            if (thisIter.next().equals(otherIter.next()) == false) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    public static int dequeHashCode(Deque<?> deque) {
+        if (deque == null) {
+            return 0;
+        }
+        return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
+    }
+
+    public static <T> Deque<T> of(T elem) {
+        var deque = new ArrayDeque<T>(1);
+        deque.offer(elem);
+        return deque;
+    }
+}
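
Why these helpers exist: ArrayDeque inherits identity-based equals() and hashCode() from Object, so records that hold a Deque would otherwise get broken value semantics; dequeEquals and dequeHashCode restore element-wise comparison. A hedged round-trip sketch (not from this diff; assumes a test method that declares throws IOException):

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.util.Deque;

// Build a deque with the of() helper, then round-trip it through the wire format.
Deque<String> chunks = DequeUtils.of("first chunk");
chunks.offer("second chunk");

try (BytesStreamOutput out = new BytesStreamOutput()) {
    out.writeCollection(chunks, StreamOutput::writeString);
    try (StreamInput in = out.bytes().streamInput()) {
        Deque<String> copy = DequeUtils.readDeque(in, StreamInput::readString);
        assert DequeUtils.dequeEquals(chunks, copy);
        assert DequeUtils.dequeHashCode(chunks) == DequeUtils.dequeHashCode(copy);
    }
}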
InferenceAction.java
@@ -16,7 +16,6 @@
 import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
 import org.elasticsearch.core.TimeValue;
@@ -342,15 +341,15 @@ public static class Response extends ActionResponse implements ChunkedToXContentObject {

         private final InferenceServiceResults results;
         private final boolean isStreaming;
-        private final Flow.Publisher<ChunkedToXContent> publisher;
+        private final Flow.Publisher<InferenceServiceResults.Result> publisher;
 
         public Response(InferenceServiceResults results) {
             this.results = results;
             this.isStreaming = false;
             this.publisher = null;
         }
 
-        public Response(InferenceServiceResults results, Flow.Publisher<ChunkedToXContent> publisher) {
+        public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
             this.results = results;
             this.isStreaming = true;
             this.publisher = publisher;
@@ -434,7 +433,7 @@ public boolean isStreaming() {
          * When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
          * If the RestResponse is closed, it will cancel the subscription.
          */
-        public Flow.Publisher<ChunkedToXContent> publisher() {
+        public Flow.Publisher<InferenceServiceResults.Result> publisher() {
             assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
             return publisher;
         }
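
For context, the consumption pattern the javadoc describes looks roughly like this. A sketch, not the PR's actual subscriber; response is assumed to be an InferenceAction.Response with isStreaming() == true, and render() is a hypothetical helper:

import java.util.concurrent.Flow;

import org.elasticsearch.inference.InferenceServiceResults;

response.publisher().subscribe(new Flow.Subscriber<>() {
    private Flow.Subscription subscription;

    @Override
    public void onSubscribe(Flow.Subscription s) {
        subscription = s;
        s.request(1); // pull the first chunk
    }

    @Override
    public void onNext(InferenceServiceResults.Result chunk) {
        render(chunk);           // hypothetical: serialize the chunk into the HTTP response
        subscription.request(1); // request the next chunk only once this one is flushed
    }

    @Override
    public void onError(Throwable throwable) {
        // surface the failure to the client and release resources
    }

    @Override
    public void onComplete() {
        // close the chunked HTTP response
    }
});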
StreamingChatCompletionResults.java
@@ -8,63 +8,43 @@
 package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.ToXContent;
 
 import java.io.IOException;
 import java.util.Deque;
 import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
 import java.util.concurrent.Flow;
 
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque;
 import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
 
 /**
  * Chat Completion results that only contain a Flow.Publisher.
  */
-public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults {
+public record StreamingChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
+    implements
+        InferenceServiceResults {
 
     @Override
     public boolean isStreaming() {
         return true;
     }
 
-    @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
+    public record Results(Deque<Result> results) implements InferenceServiceResults.Result {
+        public static final String NAME = "streaming_chat_completion_results";
 
-    @Override
-    public Map<String, Object> asMap() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public String getWriteableName() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        throw new UnsupportedOperationException("Not implemented");
-    }
+        public Results(StreamInput in) throws IOException {
+            this(readDeque(in, Result::new));
+        }
 
-    public record Results(Deque<Result> results) implements ChunkedToXContent {
         @Override
         public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
             return Iterators.concat(
@@ -75,11 +55,37 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
                 ChunkedToXContentHelper.endObject()
             );
         }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeCollection(results);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (o == null || getClass() != o.getClass()) return false;
+            Results other = (Results) o;
+            return dequeEquals(this.results, other.results());
+        }
+
+        @Override
+        public int hashCode() {
+            return dequeHashCode(results);
+        }
     }
 
-    public record Result(String delta) implements ChunkedToXContent {
+    public record Result(String delta) implements ChunkedToXContent, Writeable {
         private static final String RESULT = "delta";
 
+        private Result(StreamInput in) throws IOException {
+            this(in.readString());
+        }
+
         @Override
         public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
             return Iterators.concat(
@@ -88,5 +94,10 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
                 ChunkedToXContentHelper.endObject()
             );
         }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(delta);
+        }
     }
 }
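
End-to-end, a service can feed this record from any Flow.Publisher. A hedged sketch using the JDK's SubmissionPublisher (illustrative only; a real service manages backpressure and error signalling rather than relying on SubmissionPublisher's internal buffering):

import java.util.concurrent.SubmissionPublisher;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.DequeUtils;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;

SubmissionPublisher<InferenceServiceResults.Result> publisher = new SubmissionPublisher<>();
InferenceServiceResults results = new StreamingChatCompletionResults(publisher);

// Each upstream event becomes a Results batch wrapping one or more deltas.
publisher.submit(new StreamingChatCompletionResults.Results(DequeUtils.of(new StreamingChatCompletionResults.Result("Hello, "))));
publisher.submit(new StreamingChatCompletionResults.Results(DequeUtils.of(new StreamingChatCompletionResults.Result("world."))));
publisher.close(); // delivers onComplete() downstream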