diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 23117d0daa35a..bcfeef9f4af99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -300,7 +300,7 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl } private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { - if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + if (request.isStreaming() == false || service.canStream(model.getTaskType())) { doInference(model, request, service, listener); } else { listener.onFailure(unsupportedStreamingTaskException(request, service)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java index d9a78a56af0d6..915b0bf412a03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java @@ -44,16 +44,8 @@ public class AnthropicResponseHandler extends BaseResponseHandler { static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code"; - private final boolean canHandleStreamingResponses; - public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { - super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse); - this.canHandleStreamingResponses = canHandleStreamingResponses; - } - - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponses; + super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java index e3a74785caa4b..9227d55dc8938 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -34,16 +34,9 @@ public class CohereResponseHandler extends BaseResponseHandler { static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most"; static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response"; - private final boolean canHandleStreamingResponse; public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) { - super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse); - this.canHandleStreamingResponse = canHandleStreamingResponse; - } - - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponse; + super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse, canHandleStreamingResponse); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java index b11b4a743fb27..bd34e746cb2f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java @@ -21,6 +21,10 @@ public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse); } + public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { + super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse, canHandleStreamingResponses); + } + @Override protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { if (result.isSuccessfulResponse()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index c0bccb9b2cd49..e1438dde76c91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -20,12 +20,7 @@ public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction); - } - - @Override - public boolean canHandleStreamingResponses() { - return true; + super(requestType, parseFunction, true); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java index d61e82cb83b45..a22be46bf7576 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java @@ -28,7 +28,6 @@ public class GoogleAiStudioResponseHandler extends BaseResponseHandler { static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down"; - private final boolean canHandleStreamingResponses; private final CheckedFunction content; public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) { @@ -44,8 +43,7 @@ public GoogleAiStudioResponseHandler( boolean canHandleStreamingResponses, CheckedFunction content ) { - super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse); - this.canHandleStreamingResponses = canHandleStreamingResponses; + super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse, canHandleStreamingResponses); this.content = content; } @@ -88,11 +86,6 @@ private static String resourceNotFoundError(Request request) { return format("Resource not found at [%s]", request.getURI()); } - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponses; - } - @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 1b0dd893ada6f..ed852e5177ac0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -38,11 +38,27 @@ public abstract class BaseResponseHandler implements ResponseHandler { protected final String requestType; private final ResponseParser parseFunction; private final Function errorParseFunction; + private final boolean canHandleStreamingResponses; public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function errorParseFunction) { + this(requestType, parseFunction, errorParseFunction, false); + } + + public BaseResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction, + boolean canHandleStreamingResponses + ) { this.requestType = Objects.requireNonNull(requestType); this.parseFunction = Objects.requireNonNull(parseFunction); this.errorParseFunction = Objects.requireNonNull(errorParseFunction); + this.canHandleStreamingResponses = canHandleStreamingResponses; + } + + @Override + public boolean canHandleStreamingResponses() { + return canHandleStreamingResponses; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java index 35fe241ffae4f..0452391a76023 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java @@ -52,11 +52,8 @@ public interface ResponseHandler { /** * Returns {@code true} if the response handler can handle streaming results, or {@code false} if can only parse the entire payload. - * Defaults to {@code false}. */ - default boolean canHandleStreamingResponses() { - return false; - } + boolean canHandleStreamingResponses(); /** * A method for parsing the streamed response from the server. Implementations must invoke the diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java index cf867fb1a0ab0..e0bc341fc6792 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java @@ -41,11 +41,8 @@ public class OpenAiResponseHandler extends BaseResponseHandler { static final String OPENAI_SERVER_BUSY = "Received a server busy error status code"; - private final boolean canHandleStreamingResponses; - public OpenAiResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) { - super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse); - this.canHandleStreamingResponses = canHandleStreamingResponses; + super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses); } /** @@ -121,11 +118,6 @@ static String buildRateLimitErrorMessage(HttpResult result) { return RATE_LIMIT + ". " + usageMessage; } - @Override - public boolean canHandleStreamingResponses() { - return canHandleStreamingResponses; - } - @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java index 9dc15ea667c1d..79bb4e6ddb35b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java @@ -15,6 +15,12 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; public abstract class AmazonBedrockResponseHandler implements ResponseHandler { + + @Override + public boolean canHandleStreamingResponses() { + return false; + } + @Override public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 9695dbf0d210c..56bf6c1359a56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -37,7 +37,7 @@ import java.util.Set; public abstract class SenderService implements InferenceService { - protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY); + protected static final Set COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION); private final Sender sender; private final ServiceComponents serviceComponents; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8b8723b54d683..ea95f121ca6a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -248,10 +248,6 @@ public synchronized Set supportedStreamingTasks() { var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); authorizedStreamingTaskTypes.retainAll(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes()); - if (authorizedStreamingTaskTypes.isEmpty() == false) { - authorizedStreamingTaskTypes.add(TaskType.ANY); - } - return authorizedStreamingTaskTypes; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 8a420a62d1bce..94312a39882fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -376,7 +376,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public Set supportedStreamingTasks() { - return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY); + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java index 0349e858d9b22..92d13019bc944 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java @@ -58,4 +58,8 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t } } + @Override + public boolean canHandleStreamingResponses() { + return false; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java index 0100c2812cdc0..7e5b8e6808366 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java @@ -654,6 +654,11 @@ public InferenceServiceResults parseResult(Request request, HttpResult result) t public String getRequestType() { return "foo"; } + + @Override + public boolean canHandleStreamingResponses() { + return false; + } }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index ec41388684df1..6505c280c295a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -57,6 +57,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1381,8 +1382,8 @@ public void testInfer_UnauthorizedResponse() throws IOException { public void testSupportsStreaming() throws IOException { try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 33101a3e02661..f48cf3b9f4852 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -47,6 +47,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -655,8 +656,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 77ed889fc7361..cebea7901b956 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -61,6 +61,7 @@ import java.io.IOException; import java.net.URISyntaxException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1471,8 +1472,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 3500f11b199af..e67a5dac0e7c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -54,6 +54,7 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1548,8 +1549,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index b1c5e02fb6f51..90e5dc6890c45 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -58,6 +58,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1686,8 +1687,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index b46fd4941e6f6..fdf8520b939f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -795,7 +795,8 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { service.waitForAuthorizationToComplete(TIMEOUT); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); assertTrue(service.defaultConfigIds().isEmpty()); PlainActionFuture> listener = new PlainActionFuture<>(); @@ -932,7 +933,8 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { service.waitForAuthorizationToComplete(TIMEOUT); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); assertThat( service.defaultConfigIds(), is( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 26dae5d172fb0..d0760a583df29 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -54,6 +54,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1174,8 +1175,8 @@ public void testGetConfiguration() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 50c028fab28dd..ee93677538b33 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -56,6 +56,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1133,8 +1134,8 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception { public void testSupportsStreaming() throws IOException { try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) { - assertTrue(service.canStream(TaskType.COMPLETION)); - assertTrue(service.canStream(TaskType.ANY)); + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); } }