Skip to content

Commit 573e494

Browse files
authored
[ML] Fix stream support for TaskType.ANY (#115656) (#115865)
If we support one, then we support any.
1 parent 0a586fc commit 573e494

File tree

9 files changed

+56
-1
lines changed

9 files changed

+56
-1
lines changed

docs/changelog/115656.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 115656
2+
summary: Fix stream support for `TaskType.ANY`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2626

2727
import java.io.IOException;
28+
import java.util.EnumSet;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Objects;
3132
import java.util.Set;
3233

3334
public abstract class SenderService implements InferenceService {
34-
protected static final Set<TaskType> COMPLETION_ONLY = Set.of(TaskType.COMPLETION);
35+
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
3536
private final Sender sender;
3637
private final ServiceComponents serviceComponents;
3738

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,13 @@ public void testInfer_UnauthorizedResponse() throws IOException {
12061206
}
12071207
}
12081208

1209+
public void testSupportsStreaming() throws IOException {
1210+
try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
1211+
assertTrue(service.canStream(TaskType.COMPLETION));
1212+
assertTrue(service.canStream(TaskType.ANY));
1213+
}
1214+
}
1215+
12091216
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
12101217
var model = AmazonBedrockEmbeddingsModelTests.createModel(
12111218
"id",

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
593593
.hasErrorContaining("blah");
594594
}
595595

596+
public void testSupportsStreaming() throws IOException {
597+
try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
598+
assertTrue(service.canStream(TaskType.COMPLETION));
599+
assertTrue(service.canStream(TaskType.ANY));
600+
}
601+
}
602+
596603
private AnthropicService createServiceWithMockSender() {
597604
return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
598605
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
12731273
.hasErrorContaining("You didn't provide an API key...");
12741274
}
12751275

1276+
public void testSupportsStreaming() throws IOException {
1277+
try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
1278+
assertTrue(service.canStream(TaskType.COMPLETION));
1279+
assertTrue(service.canStream(TaskType.ANY));
1280+
}
1281+
}
1282+
12761283
// ----------------------------------------------------------------
12771284

12781285
private AzureAiStudioService createService() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
14011401
.hasErrorContaining("You didn't provide an API key...");
14021402
}
14031403

1404+
public void testSupportsStreaming() throws IOException {
1405+
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
1406+
assertTrue(service.canStream(TaskType.COMPLETION));
1407+
assertTrue(service.canStream(TaskType.ANY));
1408+
}
1409+
}
1410+
14041411
private AzureOpenAiService createAzureOpenAiService() {
14051412
return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
14061413
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
15801580
.hasErrorContaining("how dare you");
15811581
}
15821582

1583+
public void testSupportsStreaming() throws IOException {
1584+
try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
1585+
assertTrue(service.canStream(TaskType.COMPLETION));
1586+
assertTrue(service.canStream(TaskType.ANY));
1587+
}
1588+
}
1589+
15831590
private Map<String, Object> getRequestConfigMap(
15841591
Map<String, Object> serviceSettings,
15851592
Map<String, Object> taskSettings,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,13 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
11211121
}
11221122
}
11231123

1124+
public void testSupportsStreaming() throws IOException {
1125+
try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
1126+
assertTrue(service.canStream(TaskType.COMPLETION));
1127+
assertTrue(service.canStream(TaskType.ANY));
1128+
}
1129+
}
1130+
11241131
public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
11251132
return Map.of(
11261133
ChatCompletionResults.COMPLETION,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
984984
.hasErrorContaining("You didn't provide an API key...");
985985
}
986986

987+
public void testSupportsStreaming() throws IOException {
988+
try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
989+
assertTrue(service.canStream(TaskType.COMPLETION));
990+
assertTrue(service.canStream(TaskType.ANY));
991+
}
992+
}
993+
987994
public void testCheckModelConfig_IncludesMaxTokens() throws IOException {
988995
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
989996

0 commit comments

Comments
 (0)