diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e5535d85601b4..29e0e41e856b1 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -197,6 +197,7 @@ static TransportVersion def(int id) { public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49); public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50); public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -302,6 +303,7 @@ static TransportVersion def(int id) { public static final TransportVersion SECURITY_CLOUD_API_KEY_REALM_AND_TYPE = def(9_099_0_00); public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00); public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00); + public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index a112e7db26fe3..fa1e753ec5eff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) { } } - private static ResponseHandler createCustomHandler(CustomModel model) { - return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser()); + private static ResponseHandler createCustomHandler() { + return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse); } public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) { @@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) private CustomRequestManager(CustomModel model, ThreadPool threadPool) { super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); this.model = model; - this.handler = createCustomHandler(model); + this.handler = createCustomHandler(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java index 14a962b112ccd..ab67b8e726eb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java @@ -8,21 +8,34 @@ package org.elasticsearch.xpack.inference.services.custom; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; + +import java.nio.charset.StandardCharsets; +import java.util.function.Function; /** * Defines how to handle various response types returned from the custom integration. */ public class CustomResponseHandler extends BaseResponseHandler { - public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) { - super(requestType, parseFunction, errorParser); + // default for testing + static final Function ERROR_PARSER = (httpResult) -> { + try { + return new ErrorResponse(new String(httpResult.body(), StandardCharsets.UTF_8)); + } catch (Exception e) { + return new ErrorResponse(Strings.format("Failed to parse error response body: %s", e.getMessage())); + } + }; + + public CustomResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ERROR_PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 5e9aef099f622..77b852f43cd8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom serviceSettings.getQueryParameters(), serviceSettings.getRequestContentString(), serviceSettings.getResponseJsonParser(), - serviceSettings.rateLimitSettings(), - serviceSettings.getErrorParser() + serviceSettings.rateLimitSettings() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java index 0d5129b6c759c..8b8b270db3bd9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; @@ -59,7 +58,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser public static final String REQUEST = "request"; public static final String RESPONSE = "response"; public static final String JSON_PARSER = "json_parser"; - public static final String ERROR_PARSER = "error_parser"; private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE); @@ -100,15 +98,6 @@ public static CustomServiceSettings fromMap( var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException); - Map errorParserMap = extractRequiredMap( - Objects.requireNonNullElse(responseParserMap, new HashMap<>()), - ERROR_PARSER, - RESPONSE_SCOPE, - validationException - ); - - var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, DEFAULT_RATE_LIMIT_SETTINGS, @@ -117,13 +106,12 @@ public static CustomServiceSettings fromMap( context ); - if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) { + if (responseParserMap == null || jsonParserMap == null) { throw validationException; } throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME); throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME); - throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME); if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -136,8 +124,7 @@ public static CustomServiceSettings fromMap( queryParams, requestContentString, responseJsonParser, - rateLimitSettings, - errorParser + rateLimitSettings ); } @@ -209,7 +196,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private final String requestContentString; private final CustomResponseParser responseJsonParser; private final RateLimitSettings rateLimitSettings; - private final ErrorResponseParser errorParser; public CustomServiceSettings( TextEmbeddingSettings textEmbeddingSettings, @@ -218,8 +204,7 @@ public CustomServiceSettings( @Nullable QueryParameters queryParameters, String requestContentString, CustomResponseParser responseJsonParser, - @Nullable RateLimitSettings rateLimitSettings, - ErrorResponseParser errorParser + @Nullable RateLimitSettings rateLimitSettings ) { this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings); this.url = Objects.requireNonNull(url); @@ -228,7 +213,6 @@ public CustomServiceSettings( this.requestContentString = Objects.requireNonNull(requestContentString); this.responseJsonParser = Objects.requireNonNull(responseJsonParser); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); - this.errorParser = Objects.requireNonNull(errorParser); } public CustomServiceSettings(StreamInput in) throws IOException { @@ -239,7 +223,12 @@ public CustomServiceSettings(StreamInput in) throws IOException { requestContentString = in.readString(); responseJsonParser = in.readNamedWriteable(CustomResponseParser.class); rateLimitSettings = new RateLimitSettings(in); - errorParser = new ErrorResponseParser(in); + if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING) + && in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) { + // Read the error parsing fields for backwards compatibility + in.readString(); + in.readString(); + } } @Override @@ -287,10 +276,6 @@ public CustomResponseParser getResponseJsonParser() { return responseJsonParser; } - public ErrorResponseParser getErrorParser() { - return errorParser; - } - @Override public RateLimitSettings rateLimitSettings() { return rateLimitSettings; @@ -331,7 +316,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder builder.startObject(RESPONSE); { responseJsonParser.toXContent(builder, params); - errorParser.toXContent(builder, params); } builder.endObject(); @@ -359,7 +343,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(requestContentString); out.writeNamedWriteable(responseJsonParser); rateLimitSettings.writeTo(out); - errorParser.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING) + && out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) { + // Write empty strings for backwards compatibility for the error parsing fields + out.writeString(""); + out.writeString(""); + } } @Override @@ -373,8 +362,7 @@ public boolean equals(Object o) { && Objects.equals(queryParameters, that.queryParameters) && Objects.equals(requestContentString, that.requestContentString) && Objects.equals(responseJsonParser, that.responseJsonParser) - && Objects.equals(rateLimitSettings, that.rateLimitSettings) - && Objects.equals(errorParser, that.errorParser); + && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override @@ -386,8 +374,7 @@ public int hashCode() { queryParameters, requestContentString, responseJsonParser, - rateLimitSettings, - errorParser + rateLimitSettings ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java deleted file mode 100644 index 51fb8b1486a82..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.custom.response; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.common.MapPathExtractor; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; -import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER; -import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType; - -public class ErrorResponseParser implements ToXContentFragment, Function { - - private static final Logger logger = LogManager.getLogger(ErrorResponseParser.class); - public static final String MESSAGE_PATH = "path"; - - private final String messagePath; - private final String inferenceId; - - public static ErrorResponseParser fromMap( - Map responseParserMap, - String scope, - String inferenceId, - ValidationException validationException - ) { - var path = extractRequiredString(responseParserMap, MESSAGE_PATH, String.join(".", scope, ERROR_PARSER), validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new ErrorResponseParser(path, inferenceId); - } - - public ErrorResponseParser(String messagePath, String inferenceId) { - this.messagePath = Objects.requireNonNull(messagePath); - this.inferenceId = Objects.requireNonNull(inferenceId); - } - - public ErrorResponseParser(StreamInput in) throws IOException { - this.messagePath = in.readString(); - this.inferenceId = in.readString(); - } - - public void writeTo(StreamOutput out) throws IOException { - out.writeString(messagePath); - out.writeString(inferenceId); - } - - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(ERROR_PARSER); - { - builder.field(MESSAGE_PATH, messagePath); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ErrorResponseParser that = (ErrorResponseParser) o; - return Objects.equals(messagePath, that.messagePath); - } - - @Override - public int hashCode() { - return Objects.hash(messagePath); - } - - @Override - public ErrorResponse apply(HttpResult httpResult) { - try ( - XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, httpResult.body()) - ) { - var map = jsonParser.map(); - // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic - // if we find the top level error field we'll return a response with an empty message but indicate - // that we found the structure of the error object. Here if we're missing the final field we will return - // a ErrorResponse.UNDEFINED_ERROR which will indicate that we did not find the structure even if for example - // the outer error field does exist, but it doesn't contain the nested field we were looking for. - // If in the future we want the previous behavior, we can add a new message_path field or something and have - // the current path field point to the field that indicates whether we found an error object. - var errorText = toType(MapPathExtractor.extract(map, messagePath).extractedObject(), String.class, messagePath); - return new ErrorResponse(errorText); - } catch (Exception e) { - var resultAsString = new String(httpResult.body(), StandardCharsets.UTF_8); - - logger.info( - Strings.format( - "Failed to parse error object for custom service inference id [%s], message path: [%s], result as string: [%s]", - inferenceId, - messagePath, - resultAsString - ), - e - ); - - return new ErrorResponse(Strings.format("Unable to parse the error, response body: [%s]", resultAsString)); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java index 412066cb3bdd5..d98cb3d90a0e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.hamcrest.MatcherAssert; @@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r QueryParameters.EMPTY, requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java index 484585bdff68d..efcce23097fd6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.junit.After; @@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java new file mode 100644 index 0000000000000..e96bf5452e18a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandlerTests.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.services.custom.CustomResponseHandler.ERROR_PARSER; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CustomResponseHandlerTests extends ESTestCase { + + public void testErrorBodyParser() throws IOException { + var expected = XContentHelper.stripWhitespace(Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, "message")); + + assertThat(ERROR_PARSER.apply(createResult(400, "message")), is(new ErrorResponse(expected))); + } + + private static HttpResult createResult(int statusCode, String message) throws IOException { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + + String responseJson = XContentHelper.stripWhitespace(Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, message)); + + return new HttpResult(httpResponse, responseJson.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 9e1d3a8f4c8f8..3e2289e418f76 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; @@ -74,8 +73,6 @@ public static CustomServiceSettings createRandom(String inputUrl) { default -> new NoopResponseParser(); }; - var errorParser = new ErrorResponseParser("$.error.message", randomAlphaOfLength(5)); - RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); return new CustomServiceSettings( @@ -90,8 +87,7 @@ public static CustomServiceSettings createRandom(String inputUrl) { queryParameters, requestContentString, responseJsonParser, - rateLimitSettings, - errorParser + rateLimitSettings ); } @@ -133,9 +129,7 @@ public void testFromMap() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -160,8 +154,7 @@ public void testFromMap() { new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"))), requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new RateLimitSettings(10_000) ) ) ); @@ -186,9 +179,7 @@ public void testFromMap_WithOptionalsNotSpecified() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -208,8 +199,7 @@ public void testFromMap_WithOptionalsNotSpecified() { null, requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new RateLimitSettings(10_000) ) ) ); @@ -250,9 +240,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -277,8 +265,7 @@ public void testFromMap_RemovesNullValues_FromMaps() { null, requestContentString, responseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", "inference_id") + new RateLimitSettings(10_000) ) ) ); @@ -311,9 +298,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -360,9 +345,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -400,9 +383,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -434,9 +415,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -452,8 +431,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() { is( "Validation Failed: 1: [service_settings] does not contain the required setting [response];" + "2: [service_settings.response] does not contain the required setting [json_parser];" - + "3: [service_settings.response] does not contain the required setting [error_parser];" - + "4: Encountered a null input map while parsing field [path];" ) ); } @@ -481,9 +458,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() { "key", "value" ) - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) + ) ) ) ) @@ -522,8 +497,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")), "key", "value" ) @@ -545,46 +518,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() { ); } - public void testFromMap_ReturnsError_IfErrorParserMapIsNotEmptyAfterParsing() { - String url = "http://www.abc.com"; - String requestContentString = "request body"; - - var mapSettings = new HashMap( - Map.of( - CustomServiceSettings.URL, - url, - CustomServiceSettings.HEADERS, - new HashMap<>(Map.of("key", "value")), - CustomServiceSettings.REQUEST, - requestContentString, - CustomServiceSettings.RESPONSE, - new HashMap<>( - Map.of( - CustomServiceSettings.JSON_PARSER, - new HashMap<>( - Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) - ) - ) - ) - ); - - var exception = expectThrows( - ElasticsearchStatusException.class, - () -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id") - ); - - assertThat( - exception.getMessage(), - is( - "Configuration contains unknown settings [{key=value}] while parsing field [error_parser]" - + " for settings [custom_service_settings]" - ) - ); - } - public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { String url = "http://www.abc.com"; String requestContentString = "request body"; @@ -603,9 +536,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() { CustomServiceSettings.JSON_PARSER, new HashMap<>( Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") - ), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message", "key", "value")) + ) ) ) ) @@ -627,8 +558,7 @@ public void testXContent() throws IOException { null, "string", new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"), - null, - new ErrorResponseParser("$.error.message", "inference_id") + null ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -645,9 +575,6 @@ public void testXContent() throws IOException { "response": { "json_parser": { "text_embeddings": "$.result.embeddings[*].embedding" - }, - "error_parser": { - "path": "$.error.message" } }, "rate_limit": { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index f8d650144693d..dedc0d0e71ac9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.services.custom; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -31,7 +33,6 @@ import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; @@ -152,14 +153,7 @@ private static Map createServiceSettingsMap(TaskType taskType) { CustomServiceSettings.REQUEST, "request body", CustomServiceSettings.RESPONSE, - new HashMap<>( - Map.of( - CustomServiceSettings.JSON_PARSER, - createResponseParserMap(taskType), - CustomServiceSettings.ERROR_PARSER, - new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message")) - ) - ) + new HashMap<>(Map.of(CustomServiceSettings.JSON_PARSER, createResponseParserMap(taskType))) ) ); @@ -245,8 +239,7 @@ private static CustomModel createInternalEmbeddingModel( QueryParameters.EMPTY, "\"input\":\"${input}\"", parser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ), new CustomTaskSettings(Map.of("key", "test_value")), new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))) @@ -254,8 +247,6 @@ private static CustomModel createInternalEmbeddingModel( } private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) { - var inferenceId = "inference_id"; - return new CustomModel( "model_id", taskType, @@ -267,8 +258,7 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa QueryParameters.EMPTY, "\"input\":\"${input}\"", customResponseParser, - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ), new CustomTaskSettings(Map.of("key", "test_value")), new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))) @@ -281,6 +271,42 @@ private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddi : CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; } + public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException { + try (var service = createService(threadPool, clientManager)) { + String responseJson = "error"; + + webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJson)); + + var model = createInternalEmbeddingModel(new TextEmbeddingResponseParser("$.data[*].embedding"), getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("test input"), + false, + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + exception.getMessage(), + is( + Strings.format( + "Received an unsuccessful status code for request " + + "from inference entity id [inference_id] status [400]. Error message: [%s]", + responseJson + ) + ) + ); + } + } + public void testInfer_HandlesTextEmbeddingRequest_OpenAI_Format() throws IOException { try (var service = createService(threadPool, clientManager)) { String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 0f5daaf13af43..ca2726b043056 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.custom.QueryParameters; -import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser; import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -63,8 +62,7 @@ public void testCreateRequest() throws IOException { new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -117,8 +115,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() { ), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -165,8 +162,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))), requestContentString, new TextEmbeddingResponseParser("$.result.embeddings"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -213,8 +209,7 @@ public void testCreateRequest_HandlesQuery() throws IOException { null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -256,8 +251,7 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( @@ -288,8 +282,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { null, requestContentString, new RerankResponseParser("$.result.score"), - new RateLimitSettings(10_000), - new ErrorResponseParser("$.error.message", inferenceId) + new RateLimitSettings(10_000) ); var model = CustomModelTests.createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java deleted file mode 100644 index e52d7d9d0ff69..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.custom.response; - -import org.apache.http.HttpResponse; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser.MESSAGE_PATH; -import static org.hamcrest.Matchers.is; -import static org.mockito.Mockito.mock; - -public class ErrorResponseParserTests extends ESTestCase { - - public static ErrorResponseParser createRandom() { - return new ErrorResponseParser("$." + randomAlphaOfLength(5), randomAlphaOfLength(5)); - } - - public void testFromMap() { - var validation = new ValidationException(); - var parser = ErrorResponseParser.fromMap( - new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), - "scope", - "inference_id", - validation - ); - - assertThat(parser, is(new ErrorResponseParser("$.error.message", "inference_id"))); - } - - public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { - var validation = new ValidationException(); - var exception = expectThrows( - ValidationException.class, - () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), "scope", "inference_id", validation) - ); - - assertThat(exception.getMessage(), is("Validation Failed: 1: [scope.error_parser] does not contain the required setting [path];")); - } - - public void testToXContent() throws IOException { - var entity = new ErrorResponseParser("$.error.message", "inference_id"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - { - builder.startObject(); - entity.toXContent(builder, null); - builder.endObject(); - } - String xContentResult = Strings.toString(builder); - - var expected = XContentHelper.stripWhitespace(""" - { - "error_parser": { - "path": "$.error.message" - } - } - """); - - assertThat(xContentResult, is(expected)); - } - - public void testErrorResponse_ExtractsError() throws IOException { - var result = getMockResult(""" - { - "error": { - "message": "test_error_message" - } - }"""); - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(result); - assertThat(error, is(new ErrorResponse("test_error_message"))); - } - - public void testFromResponse_WithOtherFieldsPresent() throws IOException { - String responseJson = """ - { - "error": { - "message": "You didn't provide an API key", - "type": "invalid_request_error", - "param": null, - "code": null - } - } - """; - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(getMockResult(responseJson)); - - assertThat(error, is(new ErrorResponse("You didn't provide an API key"))); - } - - public void testFromResponse_noMessage() throws IOException { - String responseJson = """ - { - "error": { - "type": "not_found_error" - } - } - """; - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(getMockResult(responseJson)); - - assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"error\":{\"type\":\"not_found_error\"}}]")); - assertTrue(error.errorStructureFound()); - } - - public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOException { - var mockResult = getMockResult(""" - {"noerror":true}"""); - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(mockResult); - - assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [{\"noerror\":true}]")); - } - - public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() { - var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string")); - - var parser = new ErrorResponseParser("$.error.message", "inference_id"); - var error = parser.apply(result); - assertThat(error.getErrorMessage(), is("Unable to parse the error, response body: [not a json string]")); - } - - private static HttpResult getMockResult(String jsonString) throws IOException { - var response = mock(HttpResponse.class); - return new HttpResult(response, Strings.toUTF8Bytes(XContentHelper.stripWhitespace(jsonString))); - } -}