Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.Locale;
Expand Down Expand Up @@ -140,47 +141,27 @@ protected Exception buildError(String message, Request request, HttpResult resul
* @param request the request that caused the error
* @param result the HTTP result containing the error response
* @param errorResponse the parsed error response from the HTTP result
* @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
* @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
*/
protected UnifiedChatCompletionException buildChatCompletionError(
String message,
Request request,
HttpResult result,
ErrorResponse errorResponse,
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
ChatCompletionErrorBuilder chatCompletionErrorBuilder
StreamingErrorResponse errorResponse
) {
assert request.isStreaming() : "Only streaming requests support this format";
var statusCode = result.response().getStatusLine().getStatusCode();
var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode);
var restStatus = toRestStatus(statusCode);

return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClassSupplier, chatCompletionErrorBuilder);
}

/**
* Builds a {@link UnifiedChatCompletionException} for a streaming request.
* This method is used when an error response is received from the external service.
* Only streaming requests should use this method.
*
* @param errorResponse the error response parsed from the HTTP result
* @param errorMessage the error message to include in the exception
* @param restStatus the REST status code of the response
* @param errorResponseClassSupplier the supplier that provides the class of the expected error response type
* @param chatCompletionErrorBuilder the builder for creating provider-specific chat completion errors
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
*/
protected UnifiedChatCompletionException buildChatCompletionError(
ErrorResponse errorResponse,
String errorMessage,
RestStatus restStatus,
Supplier<Class<? extends ErrorResponse>> errorResponseClassSupplier,
ChatCompletionErrorBuilder chatCompletionErrorBuilder
) {
if (errorResponse.errorStructureFound() && errorResponseClassSupplier.get().isInstance(errorResponse)) {
return chatCompletionErrorBuilder.buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus);
if (errorResponse.errorStructureFound()) {
return new UnifiedChatCompletionException(
restStatus,
errorMessage,
errorResponse.type(),
errorResponse.code(),
errorResponse.param()
);
} else {
return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
}
Expand All @@ -196,7 +177,7 @@ protected UnifiedChatCompletionException buildChatCompletionError(
* @param restStatus the REST status code of the response
* @return an instance of {@link UnifiedChatCompletionException} with details from the error response
*/
private static UnifiedChatCompletionException buildDefaultChatCompletionError(
protected static UnifiedChatCompletionException buildDefaultChatCompletionError(
ErrorResponse errorResponse,
String errorMessage,
RestStatus restStatus
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public static ErrorResponse fromString(String response) {
private final String param;
private final String type;

StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
protected StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
super(errorMessage);
this.code = code;
this.param = param;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Flow;

Expand Down Expand Up @@ -71,29 +71,11 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR

@Override
protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
return buildChatCompletionError(
message,
request,
result,
errorResponse,
() -> GoogleVertexAiErrorResponse.class,
GoogleVertexAiUnifiedChatCompletionResponseHandler::buildProviderSpecificChatCompletionError
);
}

private static UnifiedChatCompletionException buildProviderSpecificChatCompletionError(
ErrorResponse errorResponse,
String errorMessage,
RestStatus restStatus
) {
var vertexAIErrorResponse = (GoogleVertexAiErrorResponse) errorResponse;
return new UnifiedChatCompletionException(
restStatus,
errorMessage,
vertexAIErrorResponse.status(),
String.valueOf(vertexAIErrorResponse.code()),
null
);
if (errorResponse instanceof StreamingErrorResponse streamingErrorResponse) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't go through the exercise to prove this out, but I suspect that most of the *UnifiedChatCompletionResponseHandlers will have the exact same code as this implementation. So maybe we could move the contents into the BaseResponseHandler as a new method and have the buildError() override here be a single line that simply calls back into the base class.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Responded in original PR:
elastic#128923 (comment)

return buildChatCompletionError(message, request, result, streamingErrorResponse);
} else {
return buildDefaultChatCompletionError(errorResponse, message, toRestStatus(result.response().getStatusLine().getStatusCode()));
}
}

private static UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError(
Expand All @@ -109,13 +91,13 @@ private static UnifiedChatCompletionException buildProviderSpecificMidStreamChat
inferenceEntityId,
errorResponse.getErrorMessage()
),
vertexAIErrorResponse.status(),
String.valueOf(vertexAIErrorResponse.code()),
vertexAIErrorResponse.type(),
vertexAIErrorResponse.code(),
null
);
}

public static class GoogleVertexAiErrorResponse extends ErrorResponse {
public static class GoogleVertexAiErrorResponse extends StreamingErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"google_vertex_ai_error_wrapper",
true,
Expand Down Expand Up @@ -153,7 +135,7 @@ public static ErrorResponse fromResponse(HttpResult response) {
}
}

static ErrorResponse fromString(String response) {
public static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response)
Expand All @@ -164,23 +146,8 @@ static ErrorResponse fromString(String response) {
}
}

private final int code;
@Nullable
private final String status;

GoogleVertexAiErrorResponse(Integer code, String errorMessage, @Nullable String status) {
super(Objects.requireNonNull(errorMessage));
this.code = code == null ? 0 : code;
this.status = status;
}

public int code() {
return code;
}

@Nullable
public String status() {
return status != null ? status : "google_vertex_ai_error";
super(errorMessage, code == null ? "0" : String.valueOf(code), null, status != null ? status : "google_vertex_ai_error");
}
}
}