diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index cb5ed53fc5587..3dac8d849ba6f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -76,13 +76,21 @@ public String getRequestType() { } @Override - public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) { + public void validateResponse( + ThrottlerManager throttlerManager, + Logger logger, + Request request, + HttpResult result, + boolean checkForErrorObject + ) { checkForFailureStatusCode(request, result); checkForEmptyBody(throttlerManager, logger, request, result); - // When the response is streamed the status code could be 200 but the error object will be set - // so we need to check for that specifically - checkForErrorObject(request, result); + if (checkForErrorObject) { + // When the response is streamed the status code could be 200 but the error object will be set + // so we need to check for that specifically + checkForErrorObject(request, result); + } } protected abstract void checkForFailureStatusCode(Request request, HttpResult result); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java index 0452391a76023..24ad0132c576b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java @@ -29,9 +29,12 @@ public interface ResponseHandler { * @param logger the logger to use for logging * @param request the original request * @param result the response from the server + * @param checkForErrorObject if true, the validation function should check for the presence of an error object even if the status code + * indicates a success * @throws RetryException if the response is invalid */ - void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException; + void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result, boolean checkForErrorObject) + throws RetryException; /** * A method for parsing the response from the server. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java index b71887ce6018f..d009ec87d5776 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java @@ -121,7 +121,7 @@ public void tryAction(ActionListener listener) { } else { r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> { try { - responseHandler.validateResponse(throttlerManager, logger, request, httpResult); + responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true); InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult); ll.onResponse(inferenceResults); } catch (Exception e) { @@ -134,7 +134,7 @@ public void tryAction(ActionListener listener) { } else { httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> { try { - responseHandler.validateResponse(throttlerManager, logger, request, r); + responseHandler.validateResponse(throttlerManager, logger, request, r, false); InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r); l.onResponse(inferenceResults); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java index cef709d001e22..73a9b6d570a8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java @@ -22,8 +22,13 @@ public boolean canHandleStreamingResponses() { } @Override - public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) - throws RetryException { + public final void validateResponse( + ThrottlerManager throttlerManager, + Logger logger, + Request request, + HttpResult result, + boolean checkForErrorObject + ) throws RetryException { // do nothing as the AWS SDK will take care of validation for us } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java index cf4030f541d2a..003b0f9e87720 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java @@ -63,8 +63,13 @@ public AzureMistralOpenAiExternalResponseHandler( } @Override - public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) - throws RetryException { + public void validateResponse( + ThrottlerManager throttlerManager, + Logger logger, + Request request, + HttpResult result, + boolean checkForErrorObject + ) throws RetryException { checkForFailureStatusCode(request, result); checkForEmptyBody(throttlerManager, logger, request, result); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java index 92d13019bc944..cb70b2a020f6f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java @@ -35,8 +35,14 @@ public AlwaysRetryingResponseHandler( this.parseFunction = Objects.requireNonNull(parseFunction); } - public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) - throws RetryException { + @Override + public void validateResponse( + ThrottlerManager throttlerManager, + Logger logger, + Request request, + HttpResult result, + boolean checkForErrorObject + ) throws RetryException { try { checkForFailureStatusCode(throttlerManager, logger, request, result); checkForEmptyBody(throttlerManager, logger, request, result); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java index 444a187261fff..4d3957097a969 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java @@ -59,7 +59,8 @@ public void testValidateResponse_DoesNotThrowAnExceptionWhenStatus200_AndNoError mock(ThrottlerManager.class), mock(Logger.class), request, - new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)), + true ); } @@ -85,7 +86,8 @@ public void testValidateResponse_ThrowsErrorWhenMalformedErrorObjectExists() { mock(ThrottlerManager.class), mock(Logger.class), request, - new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)), + true ) ); @@ -119,7 +121,8 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() { mock(ThrottlerManager.class), mock(Logger.class), request, - new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)), + true ) ); @@ -130,6 +133,32 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() { ); } + public void testValidateResponse_DoesNot_ThrowErrorWhenWellFormedErrorObjectExists_WhenCheckForErrorIsFalse() { + var handler = getBaseResponseHandler(); + + String responseJson = """ + { + "error": { + "type": "not_found_error", + "message": "a message" + } + } + """; + + var response = mock200Response(); + + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("abc"); + + handler.validateResponse( + mock(ThrottlerManager.class), + mock(Logger.class), + request, + new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)), + false + ); + } + private static HttpResponse mock200Response() { int statusCode = 200; var statusLine = mock(StatusLine.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java index 96401f0140813..5a79f5e0ce798 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java @@ -42,6 +42,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -76,7 +77,7 @@ public void testSend_CallsSenderAgain_AfterValidateResponseThrowsAnException() t Answer answer = (invocation) -> inferenceResults; var handler = mock(ResponseHandler.class); - doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any()); + doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any(), anyBoolean()); // Mockito.thenReturn() does not compile when returning a // bounded wild card list, thenAnswer must be used instead. when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer); @@ -351,7 +352,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnException_AfterO var handler = mock(ResponseHandler.class); doThrow(new RetryException(true, "failed")).doThrow(new IllegalStateException("failed again")) .when(handler) - .validateResponse(any(), any(), any(), any()); + .validateResponse(any(), any(), any(), any(), anyBoolean()); when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer); var retrier = createRetrier(sender); @@ -388,7 +389,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnElasticsearchExc var handler = mock(ResponseHandler.class); doThrow(new RetryException(true, "failed")).doThrow(new RetryException(false, "failed again")) .when(handler) - .validateResponse(any(), any(), any(), any()); + .validateResponse(any(), any(), any(), any(), anyBoolean()); when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer); var retrier = createRetrier(httpClient); @@ -701,8 +702,13 @@ private ResponseHandler createRetryingResponseHandler() { // testing failed requests return new ResponseHandler() { @Override - public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) - throws RetryException { + public void validateResponse( + ThrottlerManager throttlerManager, + Logger logger, + Request request, + HttpResult result, + boolean checkForErrorObject + ) throws RetryException { throw new RetryException(true, new IOException("response handler validate failed as designed")); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java index 61d0be92b2ee0..c0c319b47b688 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java @@ -96,7 +96,8 @@ private Exception invalidResponse(String responseJson) { mock(), mock(), mockRequest(), - new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)) + new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)), + true ) ); }