Skip to content
Merged
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -302,6 +303,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SECURITY_CLOUD_API_KEY_REALM_AND_TYPE = def(9_099_0_00);
public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00);
public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) {
}
}

private static ResponseHandler createCustomHandler(CustomModel model) {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser());
private static ResponseHandler createCustomHandler() {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse);
}

public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) {
Expand All @@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool)
private CustomRequestManager(CustomModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
this.model = model;
this.handler = createCustomHandler(model);
this.handler = createCustomHandler();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,34 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;

import java.nio.charset.StandardCharsets;
import java.util.function.Function;

/**
* Defines how to handle various response types returned from the custom integration.
*/
public class CustomResponseHandler extends BaseResponseHandler {
public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {
super(requestType, parseFunction, errorParser);
// default for testing
static final Function<HttpResult, ErrorResponse> ERROR_PARSER = (httpResult) -> {
try {
return new ErrorResponse(new String(httpResult.body(), StandardCharsets.UTF_8));
} catch (Exception e) {
return new ErrorResponse(Strings.format("Failed to parse error response body: %s", e.getMessage()));
}
};

public CustomResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, ERROR_PARSER);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
serviceSettings.getQueryParameters(),
serviceSettings.getRequestContentString(),
serviceSettings.getResponseJsonParser(),
serviceSettings.rateLimitSettings(),
serviceSettings.getErrorParser()
serviceSettings.rateLimitSettings()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
Expand Down Expand Up @@ -59,7 +58,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
public static final String REQUEST = "request";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";

private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
Expand Down Expand Up @@ -100,15 +98,6 @@ public static CustomServiceSettings fromMap(

var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException);

Map<String, Object> errorParserMap = extractRequiredMap(
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
ERROR_PARSER,
RESPONSE_SCOPE,
validationException
);

var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException);

RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
Expand All @@ -117,13 +106,12 @@ public static CustomServiceSettings fromMap(
context
);

if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
if (responseParserMap == null || jsonParserMap == null) {
throw validationException;
}

throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME);
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME);
throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
Expand All @@ -136,8 +124,7 @@ public static CustomServiceSettings fromMap(
queryParams,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down Expand Up @@ -209,7 +196,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
private final String requestContentString;
private final CustomResponseParser responseJsonParser;
private final RateLimitSettings rateLimitSettings;
private final ErrorResponseParser errorParser;

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
Expand All @@ -218,8 +204,7 @@ public CustomServiceSettings(
@Nullable QueryParameters queryParameters,
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
ErrorResponseParser errorParser
@Nullable RateLimitSettings rateLimitSettings
) {
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
this.url = Objects.requireNonNull(url);
Expand All @@ -228,7 +213,6 @@ public CustomServiceSettings(
this.requestContentString = Objects.requireNonNull(requestContentString);
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.errorParser = Objects.requireNonNull(errorParser);
}

public CustomServiceSettings(StreamInput in) throws IOException {
Expand All @@ -239,7 +223,12 @@ public CustomServiceSettings(StreamInput in) throws IOException {
requestContentString = in.readString();
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
rateLimitSettings = new RateLimitSettings(in);
errorParser = new ErrorResponseParser(in);
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Read the error parsing fields for backwards compatibility
in.readString();
in.readString();
}
}

@Override
Expand Down Expand Up @@ -287,10 +276,6 @@ public CustomResponseParser getResponseJsonParser() {
return responseJsonParser;
}

public ErrorResponseParser getErrorParser() {
return errorParser;
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
Expand Down Expand Up @@ -331,7 +316,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
builder.startObject(RESPONSE);
{
responseJsonParser.toXContent(builder, params);
errorParser.toXContent(builder, params);
}
builder.endObject();

Expand Down Expand Up @@ -359,7 +343,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(requestContentString);
out.writeNamedWriteable(responseJsonParser);
rateLimitSettings.writeTo(out);
errorParser.writeTo(out);
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Write empty strings for backwards compatibility for the error parsing fields
out.writeString("");
out.writeString("");
}
}

@Override
Expand All @@ -373,8 +362,7 @@ public boolean equals(Object o) {
&& Objects.equals(queryParameters, that.queryParameters)
&& Objects.equals(requestContentString, that.requestContentString)
&& Objects.equals(responseJsonParser, that.responseJsonParser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(errorParser, that.errorParser);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}

@Override
Expand All @@ -386,8 +374,7 @@ public int hashCode() {
queryParameters,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
Expand Down Expand Up @@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
QueryParameters.EMPTY,
requestContentString,
responseParser,
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.junit.After;
Expand Down Expand Up @@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
null,
requestContentString,
new RerankResponseParser("$.result.score"),
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

var model = CustomModelTests.createModel(
Expand Down
Loading