Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl
}

private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
if (request.isStreaming() == false || service.canStream(model.getTaskType())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the model's task type to determine streaming (can never be any).

doInference(model, request, service, listener);
} else {
listener.onFailure(unsupportedStreamingTaskException(request, service));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,8 @@ public class AnthropicResponseHandler extends BaseResponseHandler {

static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";

private final boolean canHandleStreamingResponses;

public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
this.canHandleStreamingResponses = canHandleStreamingResponses;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,9 @@
public class CohereResponseHandler extends BaseResponseHandler {
static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
private final boolean canHandleStreamingResponse;

public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) {
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse);
this.canHandleStreamingResponse = canHandleStreamingResponse;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponse;
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse, canHandleStreamingResponse);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser
super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse);
}

public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse, canHandleStreamingResponses);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@

public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction);
}

@Override
public boolean canHandleStreamingResponses() {
return true;
super(requestType, parseFunction, true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
public class GoogleAiStudioResponseHandler extends BaseResponseHandler {

static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down";
private final boolean canHandleStreamingResponses;
private final CheckedFunction<XContentParser, String, IOException> content;

public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) {
Expand All @@ -44,8 +43,7 @@ public GoogleAiStudioResponseHandler(
boolean canHandleStreamingResponses,
CheckedFunction<XContentParser, String, IOException> content
) {
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse);
this.canHandleStreamingResponses = canHandleStreamingResponses;
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse, canHandleStreamingResponses);
this.content = content;
}

Expand Down Expand Up @@ -88,11 +86,6 @@ private static String resourceNotFoundError(Request request) {
return format("Resource not found at [%s]", request.getURI());
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,27 @@ public abstract class BaseResponseHandler implements ResponseHandler {
protected final String requestType;
private final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
this(requestType, parseFunction, errorParseFunction, false);
}

public BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
this.requestType = Objects.requireNonNull(requestType);
this.parseFunction = Objects.requireNonNull(parseFunction);
this.errorParseFunction = Objects.requireNonNull(errorParseFunction);
this.canHandleStreamingResponses = canHandleStreamingResponses;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved here to the base class.

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,8 @@ public interface ResponseHandler {

/**
* Returns {@code true} if the response handler can handle streaming results, or {@code false} if can only parse the entire payload.
* Defaults to {@code false}.
*/
default boolean canHandleStreamingResponses() {
return false;
}
boolean canHandleStreamingResponses();

/**
* A method for parsing the streamed response from the server. Implementations must invoke the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,8 @@ public class OpenAiResponseHandler extends BaseResponseHandler {

static final String OPENAI_SERVER_BUSY = "Received a server busy error status code";

private final boolean canHandleStreamingResponses;

public OpenAiResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
this.canHandleStreamingResponses = canHandleStreamingResponses;
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses);
}

/**
Expand Down Expand Up @@ -121,11 +118,6 @@ static String buildRateLimitErrorMessage(HttpResult result) {
return RATE_LIMIT + ". " + usageMessage;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

public abstract class AmazonBedrockResponseHandler implements ResponseHandler {

@Override
public boolean canHandleStreamingResponses() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit unfortunate. This response handler is a workaround and isn't really used because this service uses its own client. AmazonBedrock does support streaming the responses but its handled in a separate place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eh we need to separate Bedrock from everything after Sender anyway, since it has its own internal threadpool and all the stuff we wrap around the http client

return false;
}

@Override
public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import java.util.Set;

public abstract class SenderService implements InferenceService {
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
private final Sender sender;
private final ServiceComponents serviceComponents;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,6 @@ public synchronized Set<TaskType> supportedStreamingTasks() {
var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION);
authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());

if (authorizedStreamingTaskTypes.isEmpty() == false) {
authorizedStreamingTaskTypes.add(TaskType.ANY);
}

return authorizedStreamingTaskTypes;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY);
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,8 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t
}
}

@Override
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this class to the test directory since that's the only place it's referenced.

public boolean canHandleStreamingResponses() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,11 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t
public String getRequestType() {
return "foo";
}

@Override
public boolean canHandleStreamingResponses() {
return false;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1381,8 +1382,8 @@ public void testInfer_UnauthorizedResponse() throws IOException {

public void testSupportsStreaming() throws IOException {
try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -655,8 +656,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1471,8 +1472,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1548,8 +1549,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1686,8 +1687,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
assertTrue(service.defaultConfigIds().isEmpty());

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
Expand Down Expand Up @@ -932,7 +933,8 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
service.waitForAuthorizationToComplete(TIMEOUT);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
assertThat(
service.defaultConfigIds(),
is(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1174,8 +1175,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1133,8 +1134,8 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down