Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,38 +76,13 @@ public String getRequestType() {
}

@Override
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) {
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);

if (checkForErrorObject) {
// When the response is streamed the status code could be 200 but the error object will be set
// so we need to check for that specifically
checkForErrorObject(request, result);
}
}

protected abstract void checkForFailureStatusCode(Request request, HttpResult result);

protected void checkForErrorObject(Request request, HttpResult result) {
var errorEntity = errorParseFunction.apply(result);

if (errorEntity.errorStructureFound()) {
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
// client retry if necessary
// If we did want to retry here, we'll need to determine if this was a streaming request, if it was
// we shouldn't retry because that would replay the entire streaming request and the client would get
// duplicate chunks back
throw new RetryException(false, buildError(SERVER_ERROR_OBJECT, request, result, errorEntity));
}
}

protected Exception buildError(String message, Request request, HttpResult result) {
var errorEntityMsg = errorParseFunction.apply(result);
return buildError(message, request, result, errorEntityMsg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,6 @@ public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParser error
this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser);
}

public void checkForErrorObject(Request request, HttpResult result) {
var errorEntity = unifiedChatCompletionErrorParser.parse(result);

if (errorEntity.errorStructureFound()) {
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
// client retry if necessary
// If we did want to retry here, we'll need to determine if this was a streaming request, if it was
// we shouldn't retry because that would replay the entire streaming request and the client would get
// duplicate chunks back
throw new RetryException(false, buildChatCompletionErrorInternal(SERVER_ERROR_OBJECT, request, result, errorEntity));
}
}

public UnifiedChatCompletionException buildChatCompletionError(String message, Request request, HttpResult result) {
var errorResponse = unifiedChatCompletionErrorParser.parse(result);
return buildChatCompletionErrorInternal(message, request, result, errorResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ public interface ResponseHandler {
* @param logger the logger to use for logging
* @param request the original request
* @param result the response from the server
* @param checkForErrorObject if true, the validation function should check for the presence of an error object even if the status code
* indicates a success
* @throws RetryException if the response is invalid
*/
void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result, boolean checkForErrorObject)
throws RetryException;
void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException;

/**
* A method for parsing the response from the server.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
} else {
r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true);
responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
ll.onResponse(inferenceResults);
} catch (Exception e) {
Expand All @@ -134,7 +134,7 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
} else {
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, r, false);
responseHandler.validateResponse(throttlerManager, logger, request, r);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r);

l.onResponse(inferenceResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@ public boolean canHandleStreamingResponses() {
}

@Override
public final void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
// do nothing as the AWS SDK will take care of validation for us
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,8 @@ public AzureMistralOpenAiExternalResponseHandler(
}

@Override
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ protected UnifiedChatCompletionException buildError(String message, Request requ
return chatCompletionErrorResponseHandler.buildChatCompletionError(message, request, result);
}

@Override
protected void checkForErrorObject(Request request, HttpResult result) {
chatCompletionErrorResponseHandler.checkForErrorObject(request, result);
}

private static class GoogleVertexAiErrorParser implements UnifiedChatCompletionErrorParser {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,8 @@ public AlwaysRetryingResponseHandler(
}

@Override
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
try {
checkForFailureStatusCode(throttlerManager, logger, request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ public void testValidateResponse_DoesNotThrowAnExceptionWhenStatus200_AndNoError
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
);
}

public void testValidateResponse_ThrowsErrorWhenMalformedErrorObjectExists() {
public void testValidateResponse_DoesNotThrowError_WhenStatus200_AndMalformedErrorObject() {
var handler = getBaseResponseHandler();

String responseJson = """
Expand All @@ -80,25 +79,15 @@ public void testValidateResponse_ThrowsErrorWhenMalformedErrorObjectExists() {
var request = mock(Request.class);
when(request.getInferenceEntityId()).thenReturn("abc");

var exception = expectThrows(
RetryException.class,
() -> handler.validateResponse(
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);

assertFalse(exception.shouldRetry());
assertThat(
exception.getCause().getMessage(),
is("Received an error response for request from inference entity id [abc] status [200]")
handler.validateResponse(
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
);
}

public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
public void testValidateResponse_DoesNotThrow_WhenStatus200_AndWellFormedErrorObjectExists() {
var handler = getBaseResponseHandler();

String responseJson = """
Expand All @@ -115,21 +104,11 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
var request = mock(Request.class);
when(request.getInferenceEntityId()).thenReturn("abc");

var exception = expectThrows(
RetryException.class,
() -> handler.validateResponse(
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);

assertFalse(exception.shouldRetry());
assertThat(
exception.getCause().getMessage(),
is("Received an error response for request from inference entity id [abc] status [200]. Error message: [a message]")
handler.validateResponse(
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
);
}

Expand All @@ -154,8 +133,7 @@ public void testValidateResponse_DoesNot_ThrowErrorWhenWellFormedErrorObjectExis
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
false
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -77,7 +76,7 @@ public void testSend_CallsSenderAgain_AfterValidateResponseThrowsAnException() t
Answer<InferenceServiceResults> answer = (invocation) -> inferenceResults;

var handler = mock(ResponseHandler.class);
doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any(), anyBoolean());
doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any());
// Mockito.thenReturn() does not compile when returning a
// bounded wild card list, thenAnswer must be used instead.
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
Expand Down Expand Up @@ -352,7 +351,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnException_AfterO
var handler = mock(ResponseHandler.class);
doThrow(new RetryException(true, "failed")).doThrow(new IllegalStateException("failed again"))
.when(handler)
.validateResponse(any(), any(), any(), any(), anyBoolean());
.validateResponse(any(), any(), any(), any());
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);

var retrier = createRetrier(sender);
Expand Down Expand Up @@ -389,7 +388,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnElasticsearchExc
var handler = mock(ResponseHandler.class);
doThrow(new RetryException(true, "failed")).doThrow(new RetryException(false, "failed again"))
.when(handler)
.validateResponse(any(), any(), any(), any(), anyBoolean());
.validateResponse(any(), any(), any(), any());
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);

var retrier = createRetrier(httpClient);
Expand Down Expand Up @@ -702,13 +701,8 @@ private ResponseHandler createRetryingResponseHandler() {
// testing failed requests
return new ResponseHandler() {
@Override
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
throw new RetryException(true, new IOException("response handler validate failed as designed"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ private Exception invalidResponse(String responseJson, int statusCode) {
mock(),
mock(),
mockRequest(),
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ private Exception invalidResponse(String responseJson) {
mock(),
mock(),
mockRequest(),
new HttpResult(mockHttpResponse(500), responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(mockHttpResponse(500), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ private Exception invalidResponse(String responseJson) {
mock(),
mock(),
mockRequest(),
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ private Exception invalidResponse(String responseJson, int statusCode) {
mock(),
mock(),
mockRequest(),
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ private Exception invalidResponse(String responseJson, int statusCode) {
mock(),
mock(),
mockRequest(),
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ private Exception invalidResponse(String responseJson) {
mock(),
mock(),
mockRequest(),
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
true
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
}
Expand Down