Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo

try (var service = createElasticInferenceService()) {
ensureAuthorizationCallFinished(service);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));

assertThat(
service.defaultConfigIds(),
Expand Down Expand Up @@ -128,7 +128,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty()
try (var service = createElasticInferenceService()) {
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
Expand Down Expand Up @@ -203,7 +203,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
try (var service = createElasticInferenceService()) {
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl
}

private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
if (request.isStreaming() == false || service.canStream(model.getTaskType())) {
doInference(model, request, service, listener);
} else {
listener.onFailure(unsupportedStreamingTaskException(request, service));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,8 @@ public class AnthropicResponseHandler extends BaseResponseHandler {

static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";

private final boolean canHandleStreamingResponses;

public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
this.canHandleStreamingResponses = canHandleStreamingResponses;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,9 @@
public class CohereResponseHandler extends BaseResponseHandler {
static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
private final boolean canHandleStreamingResponse;

public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) {
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse);
this.canHandleStreamingResponse = canHandleStreamingResponse;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponse;
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse, canHandleStreamingResponse);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser
super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse);
}

/**
 * Creates a handler whose streaming capability is chosen by the caller.
 * Forwards all arguments to the superclass constructor, supplying
 * {@code ElasticInferenceServiceErrorResponseEntity::fromResponse} as the error parser.
 *
 * @param requestType passed through to the superclass constructor
 * @param parseFunction passed through to the superclass constructor
 * @param canHandleStreamingResponses passed through to the superclass constructor
 */
public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse, canHandleStreamingResponses);
}

@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
int statusCode = result.response().getStatusLine().getStatusCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@

public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction);
}

@Override
public boolean canHandleStreamingResponses() {
return true;
super(requestType, parseFunction, true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
public class GoogleAiStudioResponseHandler extends BaseResponseHandler {

static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down";
private final boolean canHandleStreamingResponses;
private final CheckedFunction<XContentParser, String, IOException> content;

public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) {
Expand All @@ -44,8 +43,7 @@ public GoogleAiStudioResponseHandler(
boolean canHandleStreamingResponses,
CheckedFunction<XContentParser, String, IOException> content
) {
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse);
this.canHandleStreamingResponses = canHandleStreamingResponses;
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse, canHandleStreamingResponses);
this.content = content;
}

Expand Down Expand Up @@ -88,11 +86,6 @@ private static String resourceNotFoundError(Request request) {
return format("Resource not found at [%s]", request.getURI());
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,27 @@ public abstract class BaseResponseHandler implements ResponseHandler {
protected final String requestType;
private final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

/**
 * Creates a handler that does not handle streaming responses.
 * Delegates to the four-argument constructor with {@code canHandleStreamingResponses} set to {@code false}.
 *
 * @param requestType must not be {@code null}
 * @param parseFunction must not be {@code null}
 * @param errorParseFunction must not be {@code null}
 */
public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
this(requestType, parseFunction, errorParseFunction, false);
}

/**
 * Creates a handler and records whether it can handle streaming responses.
 *
 * @param requestType stored in {@code requestType}; must not be {@code null}
 * @param parseFunction stored in {@code parseFunction}; must not be {@code null}
 * @param errorParseFunction stored in {@code errorParseFunction}; must not be {@code null}
 * @param canHandleStreamingResponses value later returned by {@link #canHandleStreamingResponses()}
 * @throws NullPointerException if any of the three object arguments is {@code null}
 */
public BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
this.requestType = Objects.requireNonNull(requestType);
this.parseFunction = Objects.requireNonNull(parseFunction);
this.errorParseFunction = Objects.requireNonNull(errorParseFunction);
this.canHandleStreamingResponses = canHandleStreamingResponses;
}

/**
 * Returns the streaming flag stored in the {@code canHandleStreamingResponses} field of this handler.
 *
 * @return the value of the {@code canHandleStreamingResponses} field
 */
@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,8 @@ public interface ResponseHandler {

/**
* Returns {@code true} if the response handler can handle streaming results, or {@code false} if it can only parse the entire payload.
* Defaults to {@code false}.
*/
default boolean canHandleStreamingResponses() {
return false;
}
boolean canHandleStreamingResponses();

/**
* A method for parsing the streamed response from the server. Implementations must invoke the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ public class OpenAiResponseHandler extends BaseResponseHandler {

static final String OPENAI_SERVER_BUSY = "Received a server busy error status code";

private final boolean canHandleStreamingResponses;

public OpenAiResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
this(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses);
}
Expand All @@ -55,8 +53,7 @@ protected OpenAiResponseHandler(
Function<HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
super(requestType, parseFunction, errorParseFunction);
this.canHandleStreamingResponses = canHandleStreamingResponses;
super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses);
}

/**
Expand Down Expand Up @@ -132,11 +129,6 @@ static String buildRateLimitErrorMessage(HttpResult result) {
return RATE_LIMIT + ". " + usageMessage;
}

@Override
public boolean canHandleStreamingResponses() {
return canHandleStreamingResponses;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

public abstract class AmazonBedrockResponseHandler implements ResponseHandler {

/**
 * Always reports that this handler cannot handle streaming responses.
 *
 * @return {@code false}
 */
// NOTE(review): presumably needed because the ResponseHandler method no longer has a default body - confirm.
@Override
public boolean canHandleStreamingResponses() {
return false;
}

@Override
public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import java.util.Set;

public abstract class SenderService implements InferenceService {
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);
private final Sender sender;
private final ServiceComponents serviceComponents;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,6 @@ public synchronized Set<TaskType> supportedStreamingTasks() {
var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION);
authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());

if (authorizedStreamingTaskTypes.isEmpty() == false) {
authorizedStreamingTaskTypes.add(TaskType.ANY);
}

return authorizedStreamingTaskTypes;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY);
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,8 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t
}
}

/**
 * Always reports that this handler cannot handle streaming responses.
 *
 * @return {@code false}
 */
@Override
public boolean canHandleStreamingResponses() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,11 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t
public String getRequestType() {
return "foo";
}

// Always returns false: this handler does not handle streaming responses.
@Override
public boolean canHandleStreamingResponses() {
return false;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1381,8 +1382,8 @@ public void testInfer_UnauthorizedResponse() throws IOException {

public void testSupportsStreaming() throws IOException {
try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -655,8 +656,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1481,8 +1482,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1558,8 +1559,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.junit.Before;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1705,8 +1706,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -839,8 +839,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
ensureAuthorizationCallFinished(service);

assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
assertTrue(service.defaultConfigIds().isEmpty());

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
Expand Down Expand Up @@ -984,7 +984,8 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) {
ensureAuthorizationCallFinished(service);
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY)));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
assertThat(
service.defaultConfigIds(),
is(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1187,8 +1188,8 @@ public void testGetConfiguration() throws Exception {

public void testSupportsStreaming() throws IOException {
try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
assertTrue(service.canStream(TaskType.COMPLETION));
assertTrue(service.canStream(TaskType.ANY));
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}

Expand Down
Loading