Skip to content

Commit c321fcc

Browse files
authored
[ML] Fix stream support for TaskType.ANY (#115656) (#115864)
If we support one, then we support any.
1 parent ff59f6e commit c321fcc

File tree

9 files changed

+56
-1
lines changed

9 files changed

+56
-1
lines changed

docs/changelog/115656.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 115656
2+
summary: Fix stream support for `TaskType.ANY`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2626

2727
import java.io.IOException;
28+
import java.util.EnumSet;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Objects;
3132
import java.util.Set;
3233

3334
public abstract class SenderService implements InferenceService {
34-
protected static final Set<TaskType> COMPLETION_ONLY = Set.of(TaskType.COMPLETION);
35+
protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
3536
private final Sender sender;
3637
private final ServiceComponents serviceComponents;
3738

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,13 @@ public void testInfer_UnauthorizedResponse() throws IOException {
12061206
}
12071207
}
12081208

1209+
public void testSupportsStreaming() throws IOException {
1210+
try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
1211+
assertTrue(service.canStream(TaskType.COMPLETION));
1212+
assertTrue(service.canStream(TaskType.ANY));
1213+
}
1214+
}
1215+
12091216
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
12101217
var model = AmazonBedrockEmbeddingsModelTests.createModel(
12111218
"id",

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
595595
.hasErrorContaining("blah");
596596
}
597597

598+
public void testSupportsStreaming() throws IOException {
599+
try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
600+
assertTrue(service.canStream(TaskType.COMPLETION));
601+
assertTrue(service.canStream(TaskType.ANY));
602+
}
603+
}
604+
598605
private AnthropicService createServiceWithMockSender() {
599606
return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
600607
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
12751275
.hasErrorContaining("You didn't provide an API key...");
12761276
}
12771277

1278+
public void testSupportsStreaming() throws IOException {
1279+
try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
1280+
assertTrue(service.canStream(TaskType.COMPLETION));
1281+
assertTrue(service.canStream(TaskType.ANY));
1282+
}
1283+
}
1284+
12781285
// ----------------------------------------------------------------
12791286

12801287
private AzureAiStudioService createService() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,6 +1403,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
14031403
.hasErrorContaining("You didn't provide an API key...");
14041404
}
14051405

1406+
public void testSupportsStreaming() throws IOException {
1407+
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
1408+
assertTrue(service.canStream(TaskType.COMPLETION));
1409+
assertTrue(service.canStream(TaskType.ANY));
1410+
}
1411+
}
1412+
14061413
private AzureOpenAiService createAzureOpenAiService() {
14071414
return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
14081415
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
15821582
.hasErrorContaining("how dare you");
15831583
}
15841584

1585+
public void testSupportsStreaming() throws IOException {
1586+
try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
1587+
assertTrue(service.canStream(TaskType.COMPLETION));
1588+
assertTrue(service.canStream(TaskType.ANY));
1589+
}
1590+
}
1591+
15851592
private Map<String, Object> getRequestConfigMap(
15861593
Map<String, Object> serviceSettings,
15871594
Map<String, Object> taskSettings,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,13 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
11211121
}
11221122
}
11231123

1124+
public void testSupportsStreaming() throws IOException {
1125+
try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
1126+
assertTrue(service.canStream(TaskType.COMPLETION));
1127+
assertTrue(service.canStream(TaskType.ANY));
1128+
}
1129+
}
1130+
11241131
public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
11251132
return Map.of(
11261133
ChatCompletionResults.COMPLETION,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
986986
.hasErrorContaining("You didn't provide an API key...");
987987
}
988988

989+
public void testSupportsStreaming() throws IOException {
990+
try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
991+
assertTrue(service.canStream(TaskType.COMPLETION));
992+
assertTrue(service.canStream(TaskType.ANY));
993+
}
994+
}
995+
989996
public void testCheckModelConfig_IncludesMaxTokens() throws IOException {
990997
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
991998

0 commit comments

Comments
 (0)