diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java index 9535acfc06f49..0f27ec19e97cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.http.retry; +import org.elasticsearch.core.Nullable; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -22,9 +23,9 @@ public class ChatCompletionErrorResponseHandler { private static final String STREAM_ERROR = "stream_error"; - private final UnifiedChatCompletionErrorParser unifiedChatCompletionErrorParser; + private final UnifiedChatCompletionErrorParserContract unifiedChatCompletionErrorParser; - public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParser errorParser) { + public ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorParserContract errorParser) { this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser); } @@ -49,7 +50,7 @@ private UnifiedChatCompletionException buildChatCompletionErrorInternal( restStatus, errorMessage, errorResponse.type(), - errorResponse.code(), + code(errorResponse.code(), restStatus), errorResponse.param() ); } else { @@ -57,6 +58,10 @@ private UnifiedChatCompletionException buildChatCompletionErrorInternal( } } + private static String code(@Nullable String code, RestStatus status) { + return code != null ? code : status.name().toLowerCase(Locale.ROOT); + } + /** * Builds a default {@link UnifiedChatCompletionException} for a streaming request. * This method is used when an error response is received we were unable to parse it in the format we were expecting. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParserContract.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParser.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParserContract.java index 60f1c44919ca9..4999ed058daf8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParser.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorParserContract.java @@ -9,7 +9,7 @@ import org.elasticsearch.xpack.inference.external.http.HttpResult; -public interface UnifiedChatCompletionErrorParser { +public interface UnifiedChatCompletionErrorParserContract { UnifiedChatCompletionErrorResponse parse(HttpResult result); UnifiedChatCompletionErrorResponse parse(String result); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java index 3a70842455f1d..3fd46ee9d7487 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponse.java @@ -8,12 +8,51 @@ package org.elasticsearch.xpack.inference.external.http.retry; import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.inference.external.http.HttpResult; import java.util.Objects; +import java.util.Optional; public class UnifiedChatCompletionErrorResponse extends ErrorResponse { + + // Default for testing + static final ConstructingObjectParser, Void> ERROR_OBJECT_PARSER = + new ConstructingObjectParser<>("streaming_error", true, args -> Optional.ofNullable((UnifiedChatCompletionErrorResponse) args[0])); + private static final ConstructingObjectParser ERROR_BODY_PARSER = + new ConstructingObjectParser<>( + "streaming_error", + true, + args -> new UnifiedChatCompletionErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); + ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code")); + ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param")); + + ERROR_OBJECT_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + public static final UnifiedChatCompletionErrorParserContract ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithObjectParser(ERROR_OBJECT_PARSER); public static final UnifiedChatCompletionErrorResponse UNDEFINED_ERROR = new UnifiedChatCompletionErrorResponse(); + /** + * Standard error response parser. + * @param response The error response as an HttpResult + */ + public static UnifiedChatCompletionErrorResponse fromHttpResult(HttpResult response) { + return ERROR_PARSER.parse(response); + } + @Nullable private final String code; @Nullable diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponseUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponseUtils.java new file mode 100644 index 0000000000000..a01189d23f820 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponseUtils.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; +import java.util.concurrent.Callable; + +public class UnifiedChatCompletionErrorResponseUtils { + + /** + * Creates a {@link UnifiedChatCompletionErrorParserContract} that parses the error response as a string. + * This is useful for cases where the error response is too complicated to parse. + * + * @param type The type of the error, used for categorization. + * @return A {@link UnifiedChatCompletionErrorParserContract} instance. + */ + public static UnifiedChatCompletionErrorParserContract createErrorParserWithStringify(String type) { + return new UnifiedChatCompletionErrorParserContract() { + @Override + public UnifiedChatCompletionErrorResponse parse(HttpResult result) { + try { + String errorMessage = new String(result.body(), StandardCharsets.UTF_8); + return new UnifiedChatCompletionErrorResponse(errorMessage, type, null, null); + } catch (Exception e) { + // swallow the error + } + + return UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR; + } + + @Override + public UnifiedChatCompletionErrorResponse parse(String result) { + return new UnifiedChatCompletionErrorResponse(result, type, null, null); + } + }; + } + + /** + * Creates a {@link UnifiedChatCompletionErrorParserContract} that uses a {@link ConstructingObjectParser} to parse the error response. + * This is useful for cases where the error response can be parsed into an object. + * + * @param objectParser The {@link ConstructingObjectParser} to use for parsing the error response. + * @return A {@link UnifiedChatCompletionErrorParserContract} instance. + */ + public static UnifiedChatCompletionErrorParserContract createErrorParserWithObjectParser( + ConstructingObjectParser, Void> objectParser + ) { + return new UnifiedChatCompletionErrorParser<>((parser) -> objectParser.apply(parser, null)); + } + + /** + * Creates a {@link UnifiedChatCompletionErrorParserContract} that uses a generic parser function to parse the error response. + * This is useful for cases where the error response can be parsed using custom logic, typically when parsing from a map. + * + * @param genericParser The function that takes an {@link XContentParser} and returns an + * {@link Optional}. + * @param The type of exception that the parser can throw. + * @return A {@link UnifiedChatCompletionErrorParserContract} instance. + */ + public static UnifiedChatCompletionErrorParserContract createErrorParserWithGenericParser( + CheckedFunction, E> genericParser + ) { + return new UnifiedChatCompletionErrorParser<>(genericParser); + } + + private record UnifiedChatCompletionErrorParser( + CheckedFunction, E> genericParser + ) implements UnifiedChatCompletionErrorParserContract { + + @Override + public UnifiedChatCompletionErrorResponse parse(HttpResult result) { + return executeGenericParser(genericParser, createHttpResultXContentParserFunction(result)); + } + + @Override + public UnifiedChatCompletionErrorResponse parse(String result) { + return executeGenericParser(genericParser, createStringXContentParserFunction(result)); + } + + } + + private static Callable createHttpResultXContentParserFunction(HttpResult response) { + return () -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body()); + } + + private static Callable createStringXContentParserFunction(String response) { + return () -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response); + } + + private static UnifiedChatCompletionErrorResponse executeGenericParser( + CheckedFunction, E> genericParser, + Callable createXContentParser + ) { + try (XContentParser parser = createXContentParser.call()) { + return genericParser.apply(parser).orElse(UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR; + } + + private UnifiedChatCompletionErrorResponseUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java deleted file mode 100644 index 93e1d6388f357..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.response.streaming; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; -import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; - -import java.util.Objects; -import java.util.Optional; - -/** - * Represents an error response from a streaming inference service. - * This class extends {@link ErrorResponse} and provides additional fields - * specific to streaming errors, such as code, param, and type. - * An example error response for a streaming service might look like: - *

- *     {
- *         "error": {
- *             "message": "Invalid input",
- *             "code": "400",
- *             "param": "input",
- *             "type": "invalid_request_error"
- *         }
- *     }
- * 
- * TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication. - */ -public class StreamingErrorResponse extends ErrorResponse { - private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( - "streaming_error", - true, - args -> Optional.ofNullable((StreamingErrorResponse) args[0]) - ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( - "streaming_error", - true, - args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) - ); - - static { - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); - ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code")); - ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param")); - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); - - ERROR_PARSER.declareObjectOrNull( - ConstructingObjectParser.optionalConstructorArg(), - ERROR_BODY_PARSER, - null, - new ParseField("error") - ); - } - - /** - * Standard error response parser. This can be overridden for those subclasses that - * have a different error response structure. - * @param response The error response as an HttpResult - */ - public static ErrorResponse fromResponse(HttpResult response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response.body()) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } - - return ErrorResponse.UNDEFINED_ERROR; - } - - /** - * Standard error response parser. This can be overridden for those subclasses that - * have a different error response structure. - * @param response The error response as a string - */ - public static ErrorResponse fromString(String response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } - - return ErrorResponse.UNDEFINED_ERROR; - } - - @Nullable - private final String code; - @Nullable - private final String param; - private final String type; - - StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { - super(errorMessage); - this.code = code; - this.param = param; - this.type = Objects.requireNonNull(type); - } - - @Nullable - public String code() { - return code; - } - - @Nullable - public String param() { - return param; - } - - public String type() { - return type; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandler.java index b6f90dc92f6df..6bede6f537be7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandler.java @@ -7,18 +7,11 @@ package org.elasticsearch.xpack.inference.services.ai21.completion; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; - -import static org.elasticsearch.core.Strings.format; - /** * Handles streaming chat completion responses and error parsing for AI21 inference endpoints. * Adapts the OpenAI handler to support AI21's error schema. @@ -26,43 +19,10 @@ public class Ai21ChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { private static final String AI_21_ERROR = "ai21_error"; + private static final UnifiedChatCompletionErrorParserContract AI_21_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithStringify(AI_21_ERROR); public Ai21ChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, ErrorResponse::fromResponse); - } - - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return new UnifiedChatCompletionException(restStatus, errorMessage, AI_21_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); - } - - protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = ErrorResponse.fromString(message); - if (errorResponse.errorStructureFound()) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - AI_21_ERROR, - "stream_error" - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - createErrorType(errorResponse), - "stream_error" - ); - } + super(requestType, parseFunction, AI_21_ERROR_PARSER::parse, AI_21_ERROR_PARSER); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 382ef333bfe72..6b7a104806f2c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ChatCompletionErrorResponseHandler; -import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParser; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponse; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; @@ -65,7 +65,7 @@ protected UnifiedChatCompletionException buildError(String message, Request requ return chatCompletionErrorResponseHandler.buildChatCompletionError(message, request, result); } - private static class GoogleVertexAiErrorParser implements UnifiedChatCompletionErrorParser { + private static class GoogleVertexAiErrorParser implements UnifiedChatCompletionErrorParserContract { @Override public UnifiedChatCompletionErrorResponse parse(HttpResult result) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index c2d25e6ccaf08..81c15f6b1ddcb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -8,26 +8,21 @@ package org.elasticsearch.xpack.inference.services.huggingface; import org.elasticsearch.core.Nullable; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; +import java.io.IOException; import java.util.Optional; -import static org.elasticsearch.core.Strings.format; - /** * Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints. * Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code". @@ -35,67 +30,11 @@ public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { private static final String HUGGING_FACE_ERROR = "hugging_face_error"; + private static final UnifiedChatCompletionErrorParserContract HUGGING_FACE_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithObjectParser(StreamingHuggingFaceErrorResponseEntity.ERROR_PARSER); public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse); - } - - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof HuggingFaceErrorResponseEntity - ? new UnifiedChatCompletionException( - restStatus, - errorMessage, - HUGGING_FACE_ERROR, - restStatus.name().toLowerCase(Locale.ROOT) - ) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } - - @Override - protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message); - if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - HUGGING_FACE_ERROR, - extractErrorCode(streamingHuggingFaceErrorResponseEntity) - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - createErrorType(errorResponse), - "stream_error" - ); - } - } - - private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null - ? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode()) - : null; + super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse, HUGGING_FACE_ERROR_PARSER); } /** @@ -110,62 +49,69 @@ private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity s * } * */ - private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse { - private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( - HUGGING_FACE_ERROR, - true, - args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0]) - ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = - new ConstructingObjectParser<>( - HUGGING_FACE_ERROR, - true, - args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) - ); + private static class StreamingHuggingFaceErrorResponseEntity extends UnifiedChatCompletionErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = + new ConstructingObjectParser<>(HUGGING_FACE_ERROR, true, args -> { + if (args[0] == null) { + return Optional.empty(); + } - static { - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); - ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code")); + return Optional.of(new StreamingHuggingFaceErrorResponseEntity((ErrorField) args[0])); + }); - ERROR_PARSER.declareObjectOrNull( + static { + ERROR_PARSER.declareField( ConstructingObjectParser.optionalConstructorArg(), - ERROR_BODY_PARSER, - null, - new ParseField("error") + (p, c) -> parseErrorField(p), + new ParseField("error"), + // The expected value is an object, string, or null, using this value type to allow that combination + // We'll check the current token in the called function to ensure it is only an object, string, or null + ObjectParser.ValueType.VALUE_OBJECT_ARRAY ); } - /** - * Parses a streaming HuggingFace error response from a JSON string. - * - * @param response the raw JSON string representing an error - * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails - */ - public static ErrorResponse fromString(String response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } - - return ErrorResponse.UNDEFINED_ERROR; + StreamingHuggingFaceErrorResponseEntity(ErrorField errorField) { + super( + errorField.message, + HUGGING_FACE_ERROR, + errorField.httpStatusCode != null ? String.valueOf(errorField.httpStatusCode) : null, + null + ); } + } - @Nullable - private final Integer httpStatusCode; + private static ErrorField parseErrorField(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.VALUE_STRING) { + return ErrorField.parseString(parser); + } else if (token == XContentParser.Token.START_OBJECT) { + return ErrorField.parseObject(parser); + } else if (token == XContentParser.Token.VALUE_NULL) { + return null; + } - StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { - super(errorMessage); - this.httpStatusCode = httpStatusCode; + throw new XContentParseException("Unexpected token: " + token); + } + + private record ErrorField(String message, @Nullable Integer httpStatusCode) { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ErrorField.class.getSimpleName(), + true, + args -> new ErrorField(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code")); } - @Nullable - public Integer httpStatusCode() { - return httpStatusCode; + public static ErrorField parseObject(XContentParser parser) { + return PARSER.apply(parser, null); } + public static ErrorField parseString(XContentParser parser) throws IOException { + return new ErrorField(parser.text(), null); + } } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java index 570dbd1e709ee..33672cec9244f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java @@ -7,15 +7,11 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity.WATSONX_ERROR_PARSER; /** * Handles streaming chat completion responses and error parsing for Watsonx inference endpoints. @@ -23,29 +19,7 @@ */ public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - private static final String WATSONX_ERROR = "watsonx_error"; - public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse); - } - - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof IbmWatsonxErrorResponseEntity - ? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse, WATSONX_ERROR_PARSER); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java index 012283d54be89..883ef4151d321 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxErrorResponseEntity.java @@ -7,38 +7,39 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx.response; -import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; +import java.io.IOException; import java.util.Map; import java.util.Objects; +import java.util.Optional; -public class IbmWatsonxErrorResponseEntity extends ErrorResponse { +public class IbmWatsonxErrorResponseEntity extends UnifiedChatCompletionErrorResponse { + private static final String WATSONX_ERROR = "watsonx_error"; + public static final UnifiedChatCompletionErrorParserContract WATSONX_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithGenericParser(IbmWatsonxErrorResponseEntity::doParse); private IbmWatsonxErrorResponseEntity(String errorMessage) { - super(errorMessage); + super(errorMessage, WATSONX_ERROR, null, null); } - @SuppressWarnings("unchecked") - public static ErrorResponse fromResponse(HttpResult response) { - try ( - XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response.body()) - ) { - var responseMap = jsonParser.map(); - var error = (Map) responseMap.get("error"); - if (error != null) { - var message = (String) error.get("message"); - return new IbmWatsonxErrorResponseEntity(Objects.requireNonNullElse(message, "")); - } - } catch (Exception e) { - // swallow the error + public static UnifiedChatCompletionErrorResponse fromResponse(HttpResult result) { + return WATSONX_ERROR_PARSER.parse(result); + } + + private static Optional doParse(XContentParser parser) throws IOException { + var responseMap = parser.map(); + @SuppressWarnings("unchecked") + var error = (Map) responseMap.get("error"); + if (error != null) { + var message = (String) error.get("message"); + return Optional.of(new IbmWatsonxErrorResponseEntity(Objects.requireNonNullElse(message, ""))); } - return ErrorResponse.UNDEFINED_ERROR; + return Optional.of(UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java index c9ff10c307bf2..4187f7ee6112c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java @@ -7,26 +7,11 @@ package org.elasticsearch.xpack.inference.services.llama.completion; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; -import java.util.Optional; - -import static org.elasticsearch.core.Strings.format; - /** * Handles streaming chat completion responses and error parsing for Llama inference endpoints. * This handler is designed to work with the unified Llama chat completion API. @@ -34,85 +19,7 @@ public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { private static final String LLAMA_ERROR = "llama_error"; - private static final String STREAM_ERROR = "stream_error"; - /** - * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser. - * - * @param requestType the type of request this handler will process - * @param parseFunction the function to parse the response - */ - public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, LlamaErrorResponse::fromResponse); - } - - /** - * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type, - * @param message the error message to include in the exception - * @param request the request that caused the error - * @param result the HTTP result containing the response - * @param errorResponse the error response parsed from the HTTP result - * @return an exception representing the error, specific to Llama chat completion - */ - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof LlamaErrorResponse - ? new UnifiedChatCompletionException(restStatus, errorMessage, LLAMA_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } - - /** - * Builds an exception for mid-stream errors encountered during Llama chat completion requests. - * - * @param request the request that caused the error - * @param message the error message - * @param e the exception that occurred, if any - * @return a UnifiedChatCompletionException representing the error - */ - @Override - protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = StreamingLlamaErrorResponseEntity.fromString(message); - if (errorResponse instanceof StreamingLlamaErrorResponseEntity) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - LLAMA_ERROR, - STREAM_ERROR - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - createErrorType(errorResponse), - STREAM_ERROR - ); - } - } - - /** - * StreamingLlamaErrorResponseEntity allows creation of {@link ErrorResponse} from a JSON string. - * This entity is used to parse error responses from streaming Llama requests. - * For non-streaming requests {@link LlamaErrorResponse} should be used. * Example error response for Bad Request error would look like: *

      *  {
@@ -121,60 +28,19 @@ protected Exception buildMidStreamError(Request request, String message, Excepti
      *      }
      *  }
      * 
+ * + * This parser will simply convert the entire object into a string. */ - private static class StreamingLlamaErrorResponseEntity extends ErrorResponse { - private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( - LLAMA_ERROR, - true, - args -> Optional.ofNullable((LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity) args[0]) - ); - private static final ConstructingObjectParser< - LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity, - Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>( - LLAMA_ERROR, - true, - args -> new LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity( - args[0] != null ? (String) args[0] : "unknown" - ) - ); - - static { - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); - - ERROR_PARSER.declareObjectOrNull( - ConstructingObjectParser.optionalConstructorArg(), - ERROR_BODY_PARSER, - null, - new ParseField("error") - ); - } - - /** - * Parses a streaming Llama error response from a JSON string. - * - * @param response the raw JSON string representing an error - * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails - */ - public static ErrorResponse fromString(String response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } + private static final UnifiedChatCompletionErrorParserContract LLAMA_STREAM_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithStringify(LLAMA_ERROR); - return ErrorResponse.UNDEFINED_ERROR; - } - - /** - * Constructs a StreamingLlamaErrorResponseEntity with the specified error message. - * - * @param errorMessage the error message to include in the response entity - */ - StreamingLlamaErrorResponseEntity(String errorMessage) { - super(errorMessage); - } + /** + * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser. + * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LLAMA_STREAM_ERROR_PARSER::parse, LLAMA_STREAM_ERROR_PARSER); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java index 8e3b5b10df900..b1a1d30d8119a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.services.llama.completion; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; /** @@ -24,6 +24,6 @@ public class LlamaCompletionResponseHandler extends OpenAiChatCompletionResponse * @param parseFunction The function to parse the response. */ public LlamaCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + super(requestType, parseFunction, ErrorResponse::fromResponse); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java index 240ccf46c7482..b5a307b384d8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.services.llama.embeddings; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; /** @@ -24,6 +24,6 @@ public class LlamaEmbeddingsResponseHandler extends OpenAiResponseHandler { * @param parseFunction the function to parse the response */ public LlamaEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, LlamaErrorResponse::fromResponse, false); + super(requestType, parseFunction, ErrorResponse::fromResponse, false); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java deleted file mode 100644 index 727231209fdf1..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.llama.response; - -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; - -import java.nio.charset.StandardCharsets; - -/** - * LlamaErrorResponse is responsible for handling error responses from Llama inference services. - * It extends ErrorResponse to provide specific functionality for Llama errors. - * An example error response for Not Found error would look like: - *

- *  {
- *      "detail": "Not Found"
- *  }
- * 
- * An example error response for Bad Request error would look like: - *

- *  {
- *     "error": {
- *         "detail": {
- *             "errors": [
- *                 {
- *                     "loc": [
- *                         "body",
- *                         "model"
- *                     ],
- *                     "msg": "Field required",
- *                     "type": "missing"
- *                 }
- *             ]
- *         }
- *     }
- *  }
- * 
- */ -public class LlamaErrorResponse extends ErrorResponse { - - public LlamaErrorResponse(String message) { - super(message); - } - - public static ErrorResponse fromResponse(HttpResult response) { - try { - String errorMessage = new String(response.body(), StandardCharsets.UTF_8); - return new LlamaErrorResponse(errorMessage); - } catch (Exception e) { - // swallow the error - } - return ErrorResponse.UNDEFINED_ERROR; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index 3b7b1b81a2864..5c3315a267ea7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -7,15 +7,11 @@ package org.elasticsearch.xpack.inference.services.mistral; -import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; -import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; -import java.util.Locale; - /** * Handles streaming chat completion responses and error parsing for Mistral inference endpoints. * Adapts the OpenAI handler to support Mistral's error schema. @@ -24,16 +20,16 @@ public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedCh private static final String MISTRAL_ERROR = "mistral_error"; - public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, ErrorResponse::fromResponse); - } + private static final UnifiedChatCompletionErrorParserContract ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils + .createErrorParserWithStringify(MISTRAL_ERROR); - @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + /** + * Constructs a MistralUnifiedChatCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "mistral completions"). + * @param parseFunction The function to parse the response. + */ + public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ERROR_PARSER::parse, ERROR_PARSER); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index 8a70f4428799b..9258fa048f460 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -8,101 +8,77 @@ package org.elasticsearch.xpack.inference.services.openai; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ChatCompletionErrorResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract; +import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponse; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; -import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse; -import java.util.Locale; import java.util.concurrent.Flow; import java.util.function.Function; -import static org.elasticsearch.core.Strings.format; - /** * Handles streaming chat completion responses and error parsing for OpenAI inference endpoints. * This handler is designed to work with the unified OpenAI chat completion API. */ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + private static final ChatCompletionErrorResponseHandler DEFAULT_CHAT_COMPLETION_ERROR_RESPONSE_HANDLER = + new ChatCompletionErrorResponseHandler(UnifiedChatCompletionErrorResponse.ERROR_PARSER); + + private final ChatCompletionErrorResponseHandler chatCompletionErrorResponseHandler; + public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, StreamingErrorResponse::fromResponse); + this( + requestType, + parseFunction, + UnifiedChatCompletionErrorResponse::fromHttpResult, + DEFAULT_CHAT_COMPLETION_ERROR_RESPONSE_HANDLER + ); } public OpenAiUnifiedChatCompletionResponseHandler( String requestType, ResponseParser parseFunction, - Function errorParseFunction + Function errorParseFunction, + UnifiedChatCompletionErrorParserContract unifiedChatCompletionErrorParser ) { super(requestType, parseFunction, errorParseFunction); + this.chatCompletionErrorResponseHandler = new ChatCompletionErrorResponseHandler(unifiedChatCompletionErrorParser); + } + + private OpenAiUnifiedChatCompletionResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction, + ChatCompletionErrorResponseHandler chatCompletionErrorResponseHandler + ) { + super(requestType, parseFunction, errorParseFunction); + this.chatCompletionErrorResponseHandler = chatCompletionErrorResponseHandler; } @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> chatCompletionErrorResponseHandler.buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof StreamingErrorResponse oer - ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param()) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } - } - - protected static String createErrorType(ErrorResponse errorResponse) { - return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; - } - - protected Exception buildMidStreamError(Request request, String message, Exception e) { - return buildMidStreamError(request.getInferenceEntityId(), message, e); + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result) { + return chatCompletionErrorResponseHandler.buildChatCompletionError(message, request, result); } public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) { - var errorResponse = StreamingErrorResponse.fromString(message); - if (errorResponse instanceof StreamingErrorResponse oer) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - oer.type(), - oer.code(), - oer.param() - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), - createErrorType(errorResponse), - "stream_error" - ); - } + return DEFAULT_CHAT_COMPLETION_ERROR_RESPONSE_HANDLER.buildMidStreamChatCompletionError(inferenceEntityId, message, e); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponseUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponseUtilsTests.java new file mode 100644 index 0000000000000..4e1c255528624 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/UnifiedChatCompletionErrorResponseUtilsTests.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.retry; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class UnifiedChatCompletionErrorResponseUtilsTests extends ESTestCase { + + public void testCreateErrorParserWithStringify() { + var parser = UnifiedChatCompletionErrorResponseUtils.createErrorParserWithStringify("test"); + var errorResponseJson = """ + { + "error": "A valid user token is required" + } + """; + var errorResponse = parser.parse(new HttpResult(mock(HttpResponse.class), errorResponseJson.getBytes(StandardCharsets.UTF_8))); + assertNotNull(errorResponse); + assertEquals(errorResponseJson, errorResponse.getErrorMessage()); + } + + public void testCreateErrorParserWithObjectParser() { + var objectParser = UnifiedChatCompletionErrorResponse.ERROR_OBJECT_PARSER; + + var parser = UnifiedChatCompletionErrorResponseUtils.createErrorParserWithObjectParser(objectParser); + var errorResponseJson = """ + { + "error": { + "message": "A valid user token is required", + "type": "invalid_request_error", + "code": "code", + "param": "param" + } + } + """; + var errorResponse = parser.parse(new HttpResult(mock(HttpResponse.class), errorResponseJson.getBytes(StandardCharsets.UTF_8))); + assertThat(errorResponse.getErrorMessage(), is("A valid user token is required")); + assertThat(errorResponse.code(), is("code")); + assertThat(errorResponse.param(), is("param")); + } + + public void testCreateErrorParserWithGenericParser() { + var parser = UnifiedChatCompletionErrorResponseUtils.createErrorParserWithGenericParser( + UnifiedChatCompletionErrorResponseUtilsTests::doParse + ); + var errorResponseJson = """ + { + "error": { + "message": "A valid user token is required", + "type": "invalid_request_error", + "code": "code", + "param": "param" + } + } + """; + var errorResponse = parser.parse(errorResponseJson); + assertThat(errorResponse.getErrorMessage(), is("A valid user token is required")); + assertThat(errorResponse.code(), is("code")); + assertThat(errorResponse.param(), is("param")); + } + + private static Optional doParse(XContentParser parser) throws IOException { + var responseMap = parser.map(); + @SuppressWarnings("unchecked") + var error = (Map) responseMap.get("error"); + if (error != null) { + var message = (String) error.get("message"); + return Optional.of( + new UnifiedChatCompletionErrorResponse( + Objects.requireNonNullElse(message, ""), + "test", + (String) error.get("code"), + (String) error.get("param") + ) + ); + } + + return Optional.of(UnifiedChatCompletionErrorResponse.UNDEFINED_ERROR); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java index a23fc09df5865..88e0ea3287336 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/Ai21ServiceTests.java @@ -359,7 +359,6 @@ public void testMidStreamUnifiedCompletionError() throws Exception { testStreamError(XContentHelper.stripWhitespace(""" { "error": { - "code": "stream_error", "message": "Received an error response for request from inference entity id [id].\ Error message: [{\\"error\\": {\\"message\\": \\"400: Invalid value: Model 'ai213.12:3b' not found\\"}}]", "type": "ai21_error" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java index dc0ef6a480caf..ef404c253f156 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java @@ -51,6 +51,25 @@ public void testFailValidationWithAllFields() throws IOException { "type":"hugging_face_error"}}""")); } + public void testFailValidationWithErrorAsObject() throws IOException { + var responseJson = """ + { + "error": { + "message": "a message" + }, + "type": "validation" + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"code":"bad_request","message":"Received a server error status code for request from \ + inference entity id [id] status [500]. \ + Error message: [a message]",\ + "type":"hugging_face_error"}}""")); + } + public void testFailValidationWithoutOptionalFields() throws IOException { var responseJson = """ { @@ -75,7 +94,7 @@ public void testFailValidationWithInvalidJson() throws IOException { assertThat(errorJson, is(""" {"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\ - [500]","type":"ErrorResponse"}}""")); + [500]","type":"UnifiedChatCompletionErrorResponse"}}""")); } private String invalidResponseJson(String responseJson) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index dd68c43f5e62d..442058171bf50 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -481,9 +481,8 @@ public void testMidStreamUnifiedCompletionError() throws Exception { testStreamError(XContentHelper.stripWhitespace(""" { "error": { - "code": "stream_error", "message": "Received an error response for request from inference entity id [id].\ - Error message: [400: Invalid value: Model 'llama3.12:3b' not found]", + Error message: [{\\"error\\": {\\"message\\": \\"400: Invalid value: Model 'llama3.12:3b' not found\\"}}]", "type": "llama_error" } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java deleted file mode 100644 index aa3c6f6c20b6e..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.llama.response; - -import org.apache.http.HttpResponse; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.HttpResult; - -import java.nio.charset.StandardCharsets; - -import static org.mockito.Mockito.mock; - -public class LlamaErrorResponseTests extends ESTestCase { - - public static final String ERROR_RESPONSE_JSON = """ - { - "error": "A valid user token is required" - } - """; - - public void testFromResponse() { - var errorResponse = LlamaErrorResponse.fromResponse( - new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8)) - ); - assertNotNull(errorResponse); - assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage()); - } - -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index a64d7326a62b1..602378f2b9783 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -344,7 +344,7 @@ public void testUnifiedCompletionInfer() throws Exception { } } - public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + public void testUnifiedCompletionStreamingNotFoundError() throws Exception { String responseJson = """ { "detail": "Not Found" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java index 61d0be92b2ee0..4b3543b6cd9ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java @@ -65,9 +65,19 @@ public void testFailValidationWithoutOptionalFields() throws IOException { var errorJson = invalidResponseJson(responseJson); - assertThat(errorJson, is(""" - {"error":{"message":"Received a server error status code for request from inference entity id [abc] status [500]. \ - Error message: [a message]","type":"not_found_error"}}""")); + @SuppressWarnings("checkstyle:LineLength") + var expectedError = XContentHelper.stripWhitespace( + """ + { + "error":{ + "code":"bad_request", + "message":"Received a server error status code for request from inference entity id [abc] status [500]. Error message: [a message]", + "type":"not_found_error" + } + }""" + ); + + assertThat(errorJson, is(expectedError)); } public void testFailValidationWithInvalidJson() throws IOException { @@ -79,7 +89,7 @@ public void testFailValidationWithInvalidJson() throws IOException { assertThat(errorJson, is(""" {"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [abc] status\ - [500]","type":"ErrorResponse"}}""")); + [500]","type":"UnifiedChatCompletionErrorResponse"}}""")); } private String invalidResponseJson(String responseJson) throws IOException {