From 7e9a46360b063f0d7faf13ba5e430454f27a3ce5 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 2 Jun 2025 12:56:46 -0400 Subject: [PATCH 1/4] Remove error parsing class --- .../services/custom/CustomRequestManager.java | 6 +- .../custom/CustomResponseHandler.java | 13 +- .../services/custom/CustomService.java | 3 +- .../custom/CustomServiceSettings.java | 35 +---- .../custom/response/ErrorResponseParser.java | 128 --------------- .../services/custom/CustomModelTests.java | 4 +- .../custom/CustomRequestManagerTests.java | 4 +- .../custom/CustomServiceSettingsTests.java | 105 ++----------- .../services/custom/CustomServiceTests.java | 18 +-- .../custom/request/CustomRequestTests.java | 19 +-- .../response/ErrorResponseParserTests.java | 148 ------------------ 11 files changed, 45 insertions(+), 438 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index a112e7db26fe3..fa1e753ec5eff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) { } } - private static ResponseHandler createCustomHandler(CustomModel model) { - return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser()); + private static ResponseHandler createCustomHandler() { + return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse); } public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) { @@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) private CustomRequestManager(CustomModel model, ThreadPool threadPool) { super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); this.model = model; - this.handler = createCustomHandler(model); + this.handler = createCustomHandler(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java index 14a962b112ccd..98d4dd9ab4c2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java @@ -12,17 +12,24 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; + +import java.nio.charset.StandardCharsets; +import java.util.function.Function; /** * Defines how to handle various response types returned from the custom integration. */ public class CustomResponseHandler extends BaseResponseHandler { - public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) { - super(requestType, parseFunction, errorParser); + private static final Function ERROR_PARSER = (httpResult) -> new ErrorResponse( + new String(httpResult.body(), StandardCharsets.UTF_8) + ); + + public CustomResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ERROR_PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 5e9aef099f622..77b852f43cd8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom serviceSettings.getQueryParameters(), serviceSettings.getRequestContentString(), serviceSettings.getResponseJsonParser(), - serviceSettings.rateLimitSettings(), - serviceSettings.getErrorParser() + serviceSettings.rateLimitSettings() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index bcad5554e5915..b89f767551103 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; @@ -60,7 +59,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser public static final String REQUEST_CONTENT = "content"; public static final String RESPONSE = "response"; public static final String JSON_PARSER = "json_parser"; - public static final String ERROR_PARSER = "error_parser"; private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); @@ -108,15 +106,6 @@ public static CustomServiceSettings fromMap( var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException); - Map errorParserMap = extractRequiredMap( - Objects.requireNonNullElse(responseParserMap, new HashMap<>()), - ERROR_PARSER, - RESPONSE_SCOPE, - validationException - ); - - var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, DEFAULT_RATE_LIMIT_SETTINGS, @@ -125,14 +114,13 @@ public static CustomServiceSettings fromMap( context ); - if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null || errorParserMap == null) { + if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null) { throw validationException; } throwIfNotEmptyMap(requestBodyMap, REQUEST, NAME); throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME); throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME); - throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -145,8 +133,7 @@ public static CustomServiceSettings fromMap( queryParams, requestContentString, responseJsonParser, - rateLimitSettings, - errorParser + rateLimitSettings ); } @@ -218,7 +205,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private final String requestContentString; private final CustomResponseParser responseJsonParser; private final RateLimitSettings rateLimitSettings; - private final ErrorResponseParser errorParser; public CustomServiceSettings( TextEmbeddingSettings textEmbeddingSettings, @@ -227,8 +213,7 @@ public CustomServiceSettings( @Nullable QueryParameters queryParameters, String requestContentString, CustomResponseParser responseJsonParser, - @Nullable RateLimitSettings rateLimitSettings, - ErrorResponseParser errorParser + @Nullable RateLimitSettings rateLimitSettings ) { this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); this.url = Objects.requireNonNull(url); @@ -237,7 +222,6 @@ public CustomServiceSettings( this.requestContentString = Objects.requireNonNull(requestContentString); this.responseJsonParser = Objects.requireNonNull(responseJsonParser); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); - this.errorParser = Objects.requireNonNull(errorParser); } public CustomServiceSettings(StreamInput in) throws IOException { @@ -248,7 +232,6 @@ public CustomServiceSettings(StreamInput in) throws IOException { requestContentString = in.readString(); responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); - errorParser = new ErrorResponseParser(in); } @Override @@ -296,10 +279,6 @@ public CustomResponseParser getResponseJsonParser() { return responseJsonParser; } - public ErrorResponseParser getErrorParser() { - return errorParser; - } - @Override public RateLimitSettings rateLimitSettings() { return rateLimitSettings; @@ -344,7 +323,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder builder.startObject(RESPONSE); { responseJsonParser.toXContent(builder, params); - errorParser.toXContent(builder, params); } builder.endObject(); @@ -372,7 +350,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(requestContentString); out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); - errorParser.writeTo(out); } @Override @@ -386,8 +363,7 @@ public boolean equals(Object o) { && Objects.equals(queryParameters, that.queryParameters) && Objects.equals(requestContentString, that.requestContentString) && Objects.equals(responseJsonParser, that.responseJsonParser) - && Objects.equals(rateLimitSettings, that.rateLimitSettings) - && Objects.equals(errorParser, that.errorParser); + && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override @@ -399,8 +375,7 @@ public int hashCode() { queryParameters, requestContentString, responseJsonParser, - rateLimitSettings, - errorParser + rateLimitSettings ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java deleted file mode 100644 index 51fb8b1486a82..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.custom.response; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.common.MapPathExtractor; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; -import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER; -import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType; - -public class ErrorResponseParser implements ToXContentFragment, Function { - - private static final Logger logger = LogManager.getLogger(ErrorResponseParser.class); - public static final String MESSAGE_PATH = "path"; - - private final String messagePath; - private final String inferenceId; - - public static ErrorResponseParser fromMap( - Map responseParserMap, - String scope, - String inferenceId, - ValidationException validationException - ) { - var path = extractRequiredString(responseParserMap, MESSAGE_PATH, String.join(".", scope, ERROR_PARSER), validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new ErrorResponseParser(path, inferenceId); - } - - public ErrorResponseParser(String messagePath, String inferenceId) { - this.messagePath = Objects.requireNonNull(messagePath); - this.inferenceId = Objects.requireNonNull(inferenceId); - } - - public ErrorResponseParser(StreamInput in) throws IOException { - this.messagePath = in.readString(); - this.inferenceId = in.readString(); - } - - public void writeTo(StreamOutput out) throws IOException { - out.writeString(messagePath); - out.writeString(inferenceId); - } - - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(ERROR_PARSER); - { - builder.field(MESSAGE_PATH, messagePath); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ErrorResponseParser that = (ErrorResponseParser) o; - return Objects.equals(messagePath, that.messagePath); - } - - @Override - public int hashCode() { - return Objects.hash(messagePath); - } - - @Override - public ErrorResponse apply(HttpResult httpResult) { - try ( - XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, httpResult.body()) - ) { - var map = jsonParser.map(); - // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic - // if we find the top level error field we'll return a response with an empty message but indicate - // that we found the structure of the error object. Here if we're missing the final field we will return - // a ErrorResponse.UNDEFINED_ERROR which will indicate that we did not find the structure even if for example - // the outer error field does exist, but it doesn't contain the nested field we were looking for. - // If in the future we want the previous behavior, we can add a new message_path field or something and have - // the current path field point to the field that indicates whether we found an error object. - var errorText = toType(MapPathExtractor.extract(map, messagePath).extractedObject(), String.class, messagePath); - return new ErrorResponse(errorText); - } catch (Exception e) { - var resultAsString = new String(httpResult.body(), StandardCharsets.UTF_8); - - logger.info( - Strings.format( - "Failed to parse error object for custom service inference id [%s], message path: [%s], result as string: [%s]", - inferenceId, - messagePath, - resultAsString - ), - e - ); - - return new ErrorResponse(Strings.format("Unable to parse the error, response body: [%s]", resultAsString)); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java index c3c4a44bcab07..c4688534a44d6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; @@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r QueryParameters.EMPTY, requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java index 16c058b3e0115..bbaf9d168aeed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java @@ -16,7 +16,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString; @@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 1bb3d44b897c8..5781565c0c025 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; @@ -74,8 +73,6 @@ public static CustomServiceSettings createRandom(String inputUrl) { default -> new NoopResponseParser(); }; - var errorParser = new ErrorResponseParser("$.error.message", randomAlphaOfLength(5)); - RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); return new CustomServiceSettings( @@ -90,8 +87,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { queryParameters, requestContentString, responseJsonParser, - rateLimitSettings, - errorParser + rateLimitSettings ); } @@ -133,9 +129,7 @@ public void testFromMap() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -160,8 +154,7 @@ public void testFromMap() { new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new RateLimitSettings(10_000) ) ) ); @@ -186,9 +179,7 @@ public void testFromMap_WithOptionalsNotSpecified() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -208,8 +199,7 @@ public void testFromMap_WithOptionalsNotSpecified() { null, requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new RateLimitSettings(10_000) ) ) ); @@ -250,9 +240,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -277,8 +265,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { null, requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new RateLimitSettings(10_000) ) ) ); @@ -311,9 +298,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -360,9 +345,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -400,9 +383,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -440,9 +421,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -458,8 +437,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { is( "Validation Failed: 1: [service_settings] does not contain the required setting [response];" + "2: [service_settings.response] does not contain the required setting [json_parser];" - + "3: [service_settings.response] does not contain the required setting [error_parser];" - + "4: Encountered a null input map while parsing field [path];" ) ); } @@ -482,9 +459,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsNotEmptyAfterParsing() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -527,9 +502,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { "key", "value" ) - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -568,8 +541,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")), "key", "value" ) @@ -591,46 +562,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { ); } - public void testFromMap_ReturnsError_IfErrorParserMapIsNotEmptyAfterParsing() { - String url = "http://www.abc.com"; - String requestContentString = "request body"; - - var mapSettings = new HashMap( - Map.of( - CustomServiceSettings.URL, - url, - CustomServiceSettings.HEADERS, - new HashMap<>(Map.of("key", "value")), - CustomServiceSettings.REQUEST, - new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)), - CustomServiceSettings.RESPONSE, - new HashMap<>( - Map.of( - CustomServiceSettings.JSON_PARSER, - new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) - ) - ) - ) - ); - - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") - ); - - assertThat( - exception.getMessage(), - is( - "Configuration contains unknown settings [{key=value}] while parsing field [error_parser]" - + " for settings [custom_service_settings]" - ) - ); - } - public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { String url = "http://www.abc.com"; String requestContentString = "request body"; @@ -649,9 +580,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) + ) ) ) ) @@ -673,8 +602,7 @@ public void testXContent() throws IOException { null, "string", new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), - null, - new ErrorResponseParser("$.error.message", "inference_id") + null ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -693,9 +621,6 @@ public void testXContent() throws IOException { "response": { "json_parser": { "text_embeddings": "$.result.embeddings[*].embedding" - }, - "error_parser": { - "path": "$.error.message" } }, "rate_limit": { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index fb6c50f1bd9c4..85420b66ac56c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -30,7 +30,6 @@ import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; @@ -152,14 +151,7 @@ private static Map createServiceSettingsMap(TaskType taskType) { CustomServiceSettings.REQUEST, new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, "request body")), CustomServiceSettings.RESPONSE, - new HashMap<>( - Map.of( - CustomServiceSettings.JSON_PARSER, - createResponseParserMap(taskType), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) - ) - ) + new HashMap<>(Map.of(CustomServiceSettings.JSON_PARSER, createResponseParserMap(taskType))) ) ); @@ -245,8 +237,7 @@ private static CustomModel createInternalEmbeddingModel( QueryParameters.EMPTY, "\"input\":\"${input}\"", parser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ), new CustomTaskSettings(Map.of("key", "test_value")), new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))) @@ -254,8 +245,6 @@ private static CustomModel createInternalEmbeddingModel( } private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) { - var inferenceId = "inference_id"; - return new CustomModel( "model_id", taskType, @@ -267,8 +256,7 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa QueryParameters.EMPTY, "\"input\":\"${input}\"", customResponseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ), new CustomTaskSettings(Map.of("key", "test_value")), new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 06bfc0b1f6956..f28f33c1b4e18 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.custom.QueryParameters; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -63,8 +62,7 @@ public void testCreateRequest() throws IOException { new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -117,8 +115,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { ), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -165,8 +162,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -213,8 +209,7 @@ public void testCreateRequest_HandlesQuery() throws IOException { null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -256,8 +251,7 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -288,8 +282,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java deleted file mode 100644 index e52d7d9d0ff69..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.custom.response; - -import org.apache.http.HttpResponse; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser.MESSAGE_PATH; -import static org.hamcrest.Matchers.is; -import static org.mockito.Mockito.mock; - -public class ErrorResponseParserTests extends ESTestCase { - - public static ErrorResponseParser createRandom() { - return new ErrorResponseParser("$." + randomAlphaOfLength(5), randomAlphaOfLength(5)); - } - - public void testFromMap() { - var validation = new ValidationException(); - var parser = ErrorResponseParser.fromMap( - new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), - "scope", - "inference_id", - validation - ); - - assertThat(parser, is(new ErrorResponseParser("$.error.message", "inference_id"))); - } - - public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { - var validation = new ValidationException(); - var exception = expectThrows( - ValidationException.class, - () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), "scope", "inference_id", validation) - ); - - assertThat(exception.getMessage(), is("Validation Failed: 1: [scope.error_parser] does not contain the required setting [path];")); - } - - public void testToXContent() throws IOException { - var entity = new ErrorResponseParser("$.error.message", "inference_id"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - { - builder.startObject(); - entity.toXContent(builder, null); - builder.endObject(); - } - String xContentResult = Strings.toString(builder); - - var expected = XContentHelper.stripWhitespace(""" - { - "error_parser": { - "path": "$.error.message" - } - } - """); - - assertThat(xContentResult, is(expected)); - } - - public void testErrorResponse_ExtractsError() throws IOException { - var result = getMockResult(""" - { - "error": { - "message": "test_error_message" - } - }"""); - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(result); - assertThat(error, is(new ErrorResponse("test_error_message"))); - } - - public void testFromResponse_WithOtherFieldsPresent() throws IOException { - String responseJson = """ - { - "error": { - "message": "You didn't provide an API key", - "type": "invalid_request_error", - "param": null, - "code": null - } - } - """; - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(getMockResult(responseJson)); - - assertThat(error, is(new ErrorResponse("You didn't provide an API key"))); - } - - public void testFromResponse_noMessage() throws IOException { - String responseJson = """ - { - "error": { - "type": "not_found_error" - } - } - """; - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(getMockResult(responseJson)); - - assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"error\":{\"type\":\"not_found_error\"}}]")); - assertTrue(error.errorStructureFound()); - } - - public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOException { - var mockResult = getMockResult(""" - {"noerror":true}"""); - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(mockResult); - - assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"noerror\":true}]")); - } - - public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() { - var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string")); - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(result); - assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [not a json string]")); - } - - private static HttpResult getMockResult(String jsonString) throws IOException { - var response = mock(HttpResponse.class); - return new HttpResult(response, Strings.toUTF8Bytes(XContentHelper.stripWhitespace(jsonString))); - } -} From b34f414c849ec8a54a97dd1e4a0762219682e8c2 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 2 Jun 2025 13:16:32 -0400 Subject: [PATCH 2/4] Adding test for lack of error parsing logic --- .../services/custom/CustomServiceTests.java | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 85420b66ac56c..c93d6c1f48aa7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.services.custom; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceServiceResults; @@ -269,6 +271,42 @@ private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddi : CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; } + public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = "error"; + + webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson)); + + var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + exception.getMessage(), + is( + Strings.format( + "Received an unsuccessful status code for request " + + "from inference entity id [inference_id] status [400]. Error message: [%s]", + responseJson + ) + ) + ); + } + } + public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOException { try (var service = createService(threadPool, clientManager)) { String responseJson = """ From ed8318166c1a2c0b540dcf354243279ccc665be4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 16 Jun 2025 10:16:46 -0400 Subject: [PATCH 3/4] Adding transport version check --- .../java/org/elasticsearch/TransportVersions.java | 2 ++ .../services/custom/CustomServiceSettings.java | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 8bf8a94fccfe0..db74ebed74585 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -196,6 +196,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48); public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49); public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_51); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -298,6 +299,7 @@ static TransportVersion def(int id) { public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00); public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00); public static final TransportVersion PROJECT_DELETION_GLOBAL_BLOCK = def(9_098_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_099_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index f11acec774fd1..8b8b270db3bd9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -223,6 +223,12 @@ public CustomServiceSettings(StreamInput in) throws IOException { requestContentString = in.readString(); responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); + if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING) + && in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) { + // Read the error parsing fields for backwards compatibility + in.readString(); + in.readString(); + } } @Override @@ -337,6 +343,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(requestContentString); out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING) + && out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) { + // Write empty strings for backwards compatibility for the error parsing fields + out.writeString(""); + out.writeString(""); + } } @Override From ea1a42fc7dcacd5ffe43e2b731f4f3dd43cd464d Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 17 Jun 2025 15:33:37 -0400 Subject: [PATCH 4/4] Wrapping string in try/catch and adding test --- .../custom/CustomResponseHandler.java | 12 +++- .../custom/CustomResponseHandlerTests.java | 62 +++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java index 98d4dd9ab4c2e..ab67b8e726eb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.custom; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -24,9 +25,14 @@ * Defines how to handle various response types returned from the custom integration. */ public class CustomResponseHandler extends BaseResponseHandler { - private static final Function ERROR_PARSER = (httpResult) -> new ErrorResponse( - new String(httpResult.body(), StandardCharsets.UTF_8) - ); + // default for testing + static final Function ERROR_PARSER = (httpResult) -> { + try { + return new ErrorResponse(new String(httpResult.body(), StandardCharsets.UTF_8)); + } catch (Exception e) { + return new ErrorResponse(Strings.format("Failed to parse error response body: %s", e.getMessage())); + } + }; public CustomResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, ERROR_PARSER); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java new file mode 100644 index 0000000000000..e96bf5452e18a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.services.custom.CustomResponseHandler.ERROR_PARSER; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CustomResponseHandlerTests extends ESTestCase { + + public void testErrorBodyParser() throws IOException { + var expected = XContentHelper.stripWhitespace(Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, "message")); + + assertThat(ERROR_PARSER.apply(createResult(400, "message")), is(new ErrorResponse(expected))); + } + + private static HttpResult createResult(int statusCode, String message) throws IOException { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + + String responseJson = XContentHelper.stripWhitespace(Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, message)); + + return new HttpResult(httpResponse, responseJson.getBytes(StandardCharsets.UTF_8)); + } +}