diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 94ac888abdb4e..194feab7b0174 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -92,7 +92,7 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo try (var service = createElasticInferenceService()) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), @@ -128,7 +128,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() try (var service = createElasticInferenceService()) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), is( @@ -203,7 +203,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA try (var service = createElasticInferenceService()) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), is( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 3552fd8cacdf8..2417561cc4497 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -310,7 +310,7 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl } private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { - if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + if (request.isStreaming() == false || service.canStream(model.getTaskType())) { doInference(model, request, service, listener); } else { listener.onFailure(unsupportedStreamingTaskException(request, service)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java index 951cf1615683d..7240af8319d9a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java @@ -44,16 +44,8 @@ public class AnthropicResponseHandler extends BaseResponseHandler { static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code"; - private final boolean canHandleStreamingResponses; - public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { - super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse); - this.canHandleStreamingResponses = canHandleStreamingResponses; - } - - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponses; + super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java index 2f260c8c410a6..49d2144b9e2d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -34,16 +34,9 @@ public class CohereResponseHandler extends BaseResponseHandler { static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most"; static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response"; - private final boolean canHandleStreamingResponse; public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) { - super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse); - this.canHandleStreamingResponse = canHandleStreamingResponse; - } - - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponse; + super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse, canHandleStreamingResponse); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java index 16feafd185207..e829a9b5f2300 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java @@ -21,6 +21,10 @@ public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse); } + public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { + super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse, canHandleStreamingResponses); + } + @Override protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { int statusCode = result.response().getStatusLine().getStatusCode(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index 9e9531af06c8f..a240035468b8a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -27,12 +27,7 @@ public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction); - } - - @Override - public boolean canHandleStreamingResponses() { - return true; + super(requestType, parseFunction, true); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java index 944501dfb18f5..3ee36bf942337 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java @@ -28,7 +28,6 @@ public class GoogleAiStudioResponseHandler extends BaseResponseHandler { static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down"; - private final boolean canHandleStreamingResponses; private final CheckedFunction content; public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) { @@ -44,8 +43,7 @@ public GoogleAiStudioResponseHandler( boolean canHandleStreamingResponses, CheckedFunction content ) { - super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse); - this.canHandleStreamingResponses = canHandleStreamingResponses; + super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse, canHandleStreamingResponses); this.content = content; } @@ -88,11 +86,6 @@ private static String resourceNotFoundError(Request request) { return format("Resource not found at [%s]", request.getURI()); } - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponses; - } - @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 52a2ffba0c36c..cb5ed53fc5587 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -38,11 +38,27 @@ public abstract class BaseResponseHandler implements ResponseHandler { protected final String requestType; private final ResponseParser parseFunction; private final Function errorParseFunction; + private final boolean canHandleStreamingResponses; public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function errorParseFunction) { + this(requestType, parseFunction, errorParseFunction, false); + } + + public BaseResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction, + boolean canHandleStreamingResponses + ) { this.requestType = Objects.requireNonNull(requestType); this.parseFunction = Objects.requireNonNull(parseFunction); this.errorParseFunction = Objects.requireNonNull(errorParseFunction); + this.canHandleStreamingResponses = canHandleStreamingResponses; + } + + @Override + public boolean canHandleStreamingResponses() { + return canHandleStreamingResponses; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java index 35fe241ffae4f..0452391a76023 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java @@ -52,11 +52,8 @@ public interface ResponseHandler { /** * Returns {@code true} if the response handler can handle streaming results, or {@code false} if can only parse the entire payload. - * Defaults to {@code false}. */ - default boolean canHandleStreamingResponses() { - return false; - } + boolean canHandleStreamingResponses(); /** * A method for parsing the streamed response from the server. Implementations must invoke the diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java index 33aabf42b9583..9ade696ab8047 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java @@ -43,8 +43,6 @@ public class OpenAiResponseHandler extends BaseResponseHandler { static final String OPENAI_SERVER_BUSY = "Received a server busy error status code"; - private final boolean canHandleStreamingResponses; - public OpenAiResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { this(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses); } @@ -55,8 +53,7 @@ protected OpenAiResponseHandler( Function errorParseFunction, boolean canHandleStreamingResponses ) { - super(requestType, parseFunction, errorParseFunction); - this.canHandleStreamingResponses = canHandleStreamingResponses; + super(requestType, parseFunction, errorParseFunction, canHandleStreamingResponses); } /** @@ -132,11 +129,6 @@ static String buildRateLimitErrorMessage(HttpResult result) { return RATE_LIMIT + ". " + usageMessage; } - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponses; - } - @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java index 9dc15ea667c1d..79bb4e6ddb35b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java @@ -15,6 +15,12 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; public abstract class AmazonBedrockResponseHandler implements ResponseHandler { + + @Override + public boolean canHandleStreamingResponses() { + return false; + } + @Override public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 18378ce9f06b2..16fe128b57b2a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -37,7 +37,7 @@ import java.util.Set; public abstract class SenderService implements InferenceService { - protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY); + protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); private final Sender sender; private final ServiceComponents serviceComponents; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 628822ca5129c..315e21068a539 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -321,10 +321,6 @@ public synchronized Set supportedStreamingTasks() { var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - if (authorizedStreamingTaskTypes.isEmpty() == false) { - authorizedStreamingTaskTypes.add(TaskType.ANY); - } - return authorizedStreamingTaskTypes; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 776d42c58e89d..30973bea16ec5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -375,7 +375,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public Set supportedStreamingTasks() { - return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY); + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java index 0349e858d9b22..92d13019bc944 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java @@ -58,4 +58,8 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t } } + @Override + public boolean canHandleStreamingResponses() { + return false; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java index 0100c2812cdc0..7e5b8e6808366 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java @@ -654,6 +654,11 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t public String getRequestType() { return "foo"; } + + @Override + public boolean canHandleStreamingResponses() { + return false; + } }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index ebcba5e3cafa9..c9e06fee7f73c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -57,6 +57,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1381,8 +1382,8 @@ public void testInfer_UnauthorizedResponse() throws IOException { public void testSupportsStreaming() throws IOException { try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 33101a3e02661..f48cf3b9f4852 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -47,6 +47,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -655,8 +656,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index dabdb0ec7a3ad..9f7317a1c151f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -62,6 +62,7 @@ import java.io.IOException; import java.net.URISyntaxException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1481,8 +1482,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 585eae74b87e5..28bf5dba55ad4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -55,6 +55,7 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1558,8 +1559,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 81d5a870e867f..b3f7f10f0f2f3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -59,6 +59,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1705,8 +1706,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 864eef09622a2..30ada1c12f021 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -839,8 +839,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); assertTrue(service.defaultConfigIds().isEmpty()); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -984,7 +984,8 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); assertThat( service.defaultConfigIds(), is( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index e756bad88e843..d2f9343fa62b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -55,6 +55,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1187,8 +1188,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index d44010580d339..8b090cf4c6cfe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -59,6 +59,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1227,8 +1228,8 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } }