Skip to content

Commit 9992edc

Browse files
authored
[ML] Fix stream support for TaskType.ANY (#115656)
If we support one, then we support any.
1 parent 232622a commit 9992edc

File tree

9 files changed

+56
-1
lines changed

9 files changed

+56
-1
lines changed

docs/changelog/115656.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 115656
2+
summary: Fix stream support for `TaskType.ANY`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2626

2727
import java.io.IOException;
28+
import java.util.EnumSet;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Objects;
3132
import java.util.Set;
3233

3334
public abstract class SenderService implements InferenceService {
34-
protected static final Set<TaskType> COMPLETION_ONLY = Set.of(TaskType.COMPLETION);
35+
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
3536
private final Sender sender;
3637
private final ServiceComponents serviceComponents;
3738

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,13 @@ public void testInfer_UnauthorizedResponse() throws IOException {
13041304
}
13051305
}
13061306

1307+
public void testSupportsStreaming() throws IOException {
1308+
try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
1309+
assertTrue(service.canStream(TaskType.COMPLETION));
1310+
assertTrue(service.canStream(TaskType.ANY));
1311+
}
1312+
}
1313+
13071314
public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() throws IOException {
13081315
var model = AmazonBedrockEmbeddingsModelTests.createModel(
13091316
"id",

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
593593
.hasErrorContaining("blah");
594594
}
595595

596+
public void testSupportsStreaming() throws IOException {
597+
try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
598+
assertTrue(service.canStream(TaskType.COMPLETION));
599+
assertTrue(service.canStream(TaskType.ANY));
600+
}
601+
}
602+
596603
private AnthropicService createServiceWithMockSender() {
597604
return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
598605
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
13841384
.hasErrorContaining("You didn't provide an API key...");
13851385
}
13861386

1387+
public void testSupportsStreaming() throws IOException {
1388+
try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
1389+
assertTrue(service.canStream(TaskType.COMPLETION));
1390+
assertTrue(service.canStream(TaskType.ANY));
1391+
}
1392+
}
1393+
13871394
// ----------------------------------------------------------------
13881395

13891396
private AzureAiStudioService createService() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
15041504
.hasErrorContaining("You didn't provide an API key...");
15051505
}
15061506

1507+
public void testSupportsStreaming() throws IOException {
1508+
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
1509+
assertTrue(service.canStream(TaskType.COMPLETION));
1510+
assertTrue(service.canStream(TaskType.ANY));
1511+
}
1512+
}
1513+
15071514
private AzureOpenAiService createAzureOpenAiService() {
15081515
return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
15091516
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,6 +1683,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
16831683
.hasErrorContaining("how dare you");
16841684
}
16851685

1686+
public void testSupportsStreaming() throws IOException {
1687+
try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
1688+
assertTrue(service.canStream(TaskType.COMPLETION));
1689+
assertTrue(service.canStream(TaskType.ANY));
1690+
}
1691+
}
1692+
16861693
private Map<String, Object> getRequestConfigMap(
16871694
Map<String, Object> serviceSettings,
16881695
Map<String, Object> taskSettings,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,13 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
12191219
}
12201220
}
12211221

1222+
public void testSupportsStreaming() throws IOException {
1223+
try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
1224+
assertTrue(service.canStream(TaskType.COMPLETION));
1225+
assertTrue(service.canStream(TaskType.ANY));
1226+
}
1227+
}
1228+
12221229
public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
12231230
return Map.of(
12241231
ChatCompletionResults.COMPLETION,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
10771077
.hasErrorContaining("You didn't provide an API key...");
10781078
}
10791079

1080+
public void testSupportsStreaming() throws IOException {
1081+
try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
1082+
assertTrue(service.canStream(TaskType.COMPLETION));
1083+
assertTrue(service.canStream(TaskType.ANY));
1084+
}
1085+
}
1086+
10801087
public void testCheckModelConfig_IncludesMaxTokens() throws IOException {
10811088
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
10821089

0 commit comments

Comments
 (0)