Skip to content
Merged
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_51);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -301,6 +302,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SECURITY_CLOUD_API_KEY_REALM_AND_TYPE = def(9_099_0_00);
public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00);
public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) {
}
}

private static ResponseHandler createCustomHandler(CustomModel model) {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser());
private static ResponseHandler createCustomHandler() {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse);
}

public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) {
Expand All @@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool)
private CustomRequestManager(CustomModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
this.model = model;
this.handler = createCustomHandler(model);
this.handler = createCustomHandler();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,24 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;

import java.nio.charset.StandardCharsets;
import java.util.function.Function;

/**
* Defines how to handle various response types returned from the custom integration.
*/
public class CustomResponseHandler extends BaseResponseHandler {
public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {
super(requestType, parseFunction, errorParser);
private static final Function<HttpResult, ErrorResponse> ERROR_PARSER = (httpResult) -> new ErrorResponse(
new String(httpResult.body(), StandardCharsets.UTF_8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any case where this might fail to parse and we'd want some sort of generic fallback string? From the java docs it seems "The behavior of this constructor when the given bytes are not valid in the given charset is unspecified." (here) so I'm not quite clear on how this might react if it can't parse the bytes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point, I'll wrap it in a try/catch and return something generic 👍

);

public CustomResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, ERROR_PARSER);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
serviceSettings.getQueryParameters(),
serviceSettings.getRequestContentString(),
serviceSettings.getResponseJsonParser(),
serviceSettings.rateLimitSettings(),
serviceSettings.getErrorParser()
serviceSettings.rateLimitSettings()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
Expand Down Expand Up @@ -59,7 +58,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
public static final String REQUEST = "request";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";

private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
Expand Down Expand Up @@ -100,15 +98,6 @@ public static CustomServiceSettings fromMap(

var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException);

Map<String, Object> errorParserMap = extractRequiredMap(
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
ERROR_PARSER,
RESPONSE_SCOPE,
validationException
);

var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException);

RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
Expand All @@ -117,13 +106,12 @@ public static CustomServiceSettings fromMap(
context
);

if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
if (responseParserMap == null || jsonParserMap == null) {
throw validationException;
}

throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME);
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME);
throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
Expand All @@ -136,8 +124,7 @@ public static CustomServiceSettings fromMap(
queryParams,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down Expand Up @@ -209,7 +196,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
private final String requestContentString;
private final CustomResponseParser responseJsonParser;
private final RateLimitSettings rateLimitSettings;
private final ErrorResponseParser errorParser;

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
Expand All @@ -218,8 +204,7 @@ public CustomServiceSettings(
@Nullable QueryParameters queryParameters,
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
ErrorResponseParser errorParser
@Nullable RateLimitSettings rateLimitSettings
) {
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
this.url = Objects.requireNonNull(url);
Expand All @@ -228,7 +213,6 @@ public CustomServiceSettings(
this.requestContentString = Objects.requireNonNull(requestContentString);
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.errorParser = Objects.requireNonNull(errorParser);
}

public CustomServiceSettings(StreamInput in) throws IOException {
Expand All @@ -239,7 +223,12 @@ public CustomServiceSettings(StreamInput in) throws IOException {
requestContentString = in.readString();
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
rateLimitSettings = new RateLimitSettings(in);
errorParser = new ErrorResponseParser(in);
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Read the error parsing fields for backwards compatibility
in.readString();
in.readString();
}
}

@Override
Expand Down Expand Up @@ -287,10 +276,6 @@ public CustomResponseParser getResponseJsonParser() {
return responseJsonParser;
}

public ErrorResponseParser getErrorParser() {
return errorParser;
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
Expand Down Expand Up @@ -331,7 +316,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
builder.startObject(RESPONSE);
{
responseJsonParser.toXContent(builder, params);
errorParser.toXContent(builder, params);
}
builder.endObject();

Expand Down Expand Up @@ -359,7 +343,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(requestContentString);
out.writeNamedWriteable(responseJsonParser);
rateLimitSettings.writeTo(out);
errorParser.writeTo(out);
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Write empty strings for backwards compatibility for the error parsing fields
out.writeString("");
out.writeString("");
}
}

@Override
Expand All @@ -373,8 +362,7 @@ public boolean equals(Object o) {
&& Objects.equals(queryParameters, that.queryParameters)
&& Objects.equals(requestContentString, that.requestContentString)
&& Objects.equals(responseJsonParser, that.responseJsonParser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(errorParser, that.errorParser);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}

@Override
Expand All @@ -386,8 +374,7 @@ public int hashCode() {
queryParameters,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
Expand Down Expand Up @@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
QueryParameters.EMPTY,
requestContentString,
responseParser,
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.junit.After;
Expand Down Expand Up @@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
null,
requestContentString,
new RerankResponseParser("$.result.score"),
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

var model = CustomModelTests.createModel(
Expand Down
Loading