Skip to content

Commit 7f284d3

Browse files
[ML] Remove tasktype any from supportedStreamingTasks (#121460)
* Refactoring supported streaming functionality * Moving always retrying handler to the tests * Fixing comment
1 parent 4541b12 commit 7f284d3

23 files changed

+69
-67
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl
300300
}
301301

302302
private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
303-
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
303+
if (request.isStreaming() == false || service.canStream(model.getTaskType())) {
304304
doInference(model, request, service, listener);
305305
} else {
306306
listener.onFailure(unsupportedStreamingTaskException(request, service));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,8 @@ public class AnthropicResponseHandler extends BaseResponseHandler {
4444

4545
static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";
4646

47-
private final boolean canHandleStreamingResponses;
48-
4947
public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
50-
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
51-
this.canHandleStreamingResponses = canHandleStreamingResponses;
52-
}
53-
54-
@Override
55-
public boolean canHandleStreamingResponses() {
56-
return canHandleStreamingResponses;
48+
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses);
5749
}
5850

5951
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,9 @@
3434
public class CohereResponseHandler extends BaseResponseHandler {
3535
static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most";
3636
static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response";
37-
private final boolean canHandleStreamingResponse;
3837

3938
public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) {
40-
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse);
41-
this.canHandleStreamingResponse = canHandleStreamingResponse;
42-
}
43-
44-
@Override
45-
public boolean canHandleStreamingResponses() {
46-
return canHandleStreamingResponse;
39+
super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse, canHandleStreamingResponse);
4740
}
4841

4942
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser
2121
super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse);
2222
}
2323

24+
public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
25+
super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse, canHandleStreamingResponses);
26+
}
27+
2428
@Override
2529
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
2630
if (result.isSuccessfulResponse()) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@
2020

2121
public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
2222
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
23-
super(requestType, parseFunction);
24-
}
25-
26-
@Override
27-
public boolean canHandleStreamingResponses() {
28-
return true;
23+
super(requestType, parseFunction, true);
2924
}
3025

3126
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
public class GoogleAiStudioResponseHandler extends BaseResponseHandler {
2929

3030
static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down";
31-
private final boolean canHandleStreamingResponses;
3231
private final CheckedFunction<XContentParser, String, IOException> content;
3332

3433
public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) {
@@ -44,8 +43,7 @@ public GoogleAiStudioResponseHandler(
4443
boolean canHandleStreamingResponses,
4544
CheckedFunction<XContentParser, String, IOException> content
4645
) {
47-
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse);
48-
this.canHandleStreamingResponses = canHandleStreamingResponses;
46+
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse, canHandleStreamingResponses);
4947
this.content = content;
5048
}
5149

@@ -88,11 +86,6 @@ private static String resourceNotFoundError(Request request) {
8886
return format("Resource not found at [%s]", request.getURI());
8987
}
9088

91-
@Override
92-
public boolean canHandleStreamingResponses() {
93-
return canHandleStreamingResponses;
94-
}
95-
9689
@Override
9790
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
9891
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,27 @@ public abstract class BaseResponseHandler implements ResponseHandler {
3838
protected final String requestType;
3939
private final ResponseParser parseFunction;
4040
private final Function<HttpResult, ErrorResponse> errorParseFunction;
41+
private final boolean canHandleStreamingResponses;
4142

4243
public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
44+
this(requestType, parseFunction, errorParseFunction, false);
45+
}
46+
47+
public BaseResponseHandler(
48+
String requestType,
49+
ResponseParser parseFunction,
50+
Function<HttpResult, ErrorResponse> errorParseFunction,
51+
boolean canHandleStreamingResponses
52+
) {
4353
this.requestType = Objects.requireNonNull(requestType);
4454
this.parseFunction = Objects.requireNonNull(parseFunction);
4555
this.errorParseFunction = Objects.requireNonNull(errorParseFunction);
56+
this.canHandleStreamingResponses = canHandleStreamingResponses;
57+
}
58+
59+
@Override
60+
public boolean canHandleStreamingResponses() {
61+
return canHandleStreamingResponses;
4662
}
4763

4864
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,8 @@ public interface ResponseHandler {
5252

5353
/**
5454
* Returns {@code true} if the response handler can handle streaming results, or {@code false} if can only parse the entire payload.
55-
* Defaults to {@code false}.
5655
*/
57-
default boolean canHandleStreamingResponses() {
58-
return false;
59-
}
56+
boolean canHandleStreamingResponses();
6057

6158
/**
6259
* A method for parsing the streamed response from the server. Implementations must invoke the

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,8 @@ public class OpenAiResponseHandler extends BaseResponseHandler {
4141

4242
static final String OPENAI_SERVER_BUSY = "Received a server busy error status code";
4343

44-
private final boolean canHandleStreamingResponses;
45-
4644
public OpenAiResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
47-
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
48-
this.canHandleStreamingResponses = canHandleStreamingResponses;
45+
super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse, canHandleStreamingResponses);
4946
}
5047

5148
/**
@@ -121,11 +118,6 @@ static String buildRateLimitErrorMessage(HttpResult result) {
121118
return RATE_LIMIT + ". " + usageMessage;
122119
}
123120

124-
@Override
125-
public boolean canHandleStreamingResponses() {
126-
return canHandleStreamingResponses;
127-
}
128-
129121
@Override
130122
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
131123
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
1616

1717
public abstract class AmazonBedrockResponseHandler implements ResponseHandler {
18+
19+
@Override
20+
public boolean canHandleStreamingResponses() {
21+
return false;
22+
}
23+
1824
@Override
1925
public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
2026
throws RetryException {

0 commit comments

Comments
 (0)