-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Adding response parsers for custom service #127179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
53af9be
5b8c573
27d2a70
fa1dd88
90d2320
8685ddd
69c00ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.custom; | ||
|
|
||
| public class CustomServiceSettings { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not all of these values are referenced but they will be once I pull in the rest of the custom service changes in future PRs. |
||
| public static final String NAME = "custom_service_settings"; | ||
| public static final String URL = "url"; | ||
| public static final String HEADERS = "headers"; | ||
| public static final String REQUEST = "request"; | ||
| public static final String REQUEST_CONTENT = "content"; | ||
| public static final String RESPONSE = "response"; | ||
| public static final String JSON_PARSER = "json_parser"; | ||
| public static final String ERROR_PARSER = "error_parser"; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.custom.response; | ||
|
|
||
| import org.elasticsearch.common.Strings; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xcontent.XContentParser; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
| import java.util.function.Function; | ||
|
|
||
| public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser { | ||
|
|
||
| @Override | ||
| public InferenceServiceResults parse(HttpResult response) throws IOException { | ||
| try ( | ||
| XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) | ||
| .createParser(XContentParserConfiguration.EMPTY, response.body()) | ||
| ) { | ||
| var map = jsonParser.map(); | ||
|
|
||
| return transform(map); | ||
| } | ||
| } | ||
|
|
||
| protected abstract T transform(Map<String, Object> extractedField); | ||
|
|
||
| static List<?> validateList(Object obj) { | ||
| validateNonNull(obj); | ||
|
|
||
| if (obj instanceof List<?> == false) { | ||
| throw new IllegalArgumentException( | ||
| Strings.format("Extracted field is an invalid type, expected a list but received [%s]", obj.getClass().getSimpleName()) | ||
|
||
| ); | ||
| } | ||
|
|
||
| return (List<?>) obj; | ||
| } | ||
|
|
||
| static void validateNonNull(Object obj) { | ||
| Objects.requireNonNull(obj, "Failed to parse response, extracted field was null"); | ||
| } | ||
|
|
||
| static Map<String, Object> validateMap(Object obj) { | ||
| validateNonNull(obj); | ||
|
|
||
| if (obj instanceof Map<?, ?> == false) { | ||
| throw new IllegalArgumentException( | ||
| Strings.format("Extracted field is an invalid type, expected a map but received [%s]", obj.getClass().getSimpleName()) | ||
| ); | ||
| } | ||
|
|
||
| var keys = ((Map<?, ?>) obj).keySet(); | ||
| for (var key : keys) { | ||
| if (key instanceof String == false) { | ||
| throw new IllegalStateException( | ||
| Strings.format( | ||
| "Extracted map has an invalid key type. Expected a string but received [%s]", | ||
| key.getClass().getSimpleName() | ||
| ) | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| @SuppressWarnings("unchecked") | ||
| var result = (Map<String, Object>) obj; | ||
| return result; | ||
| } | ||
|
|
||
| static List<Float> convertToListOfFloats(Object obj) { | ||
| return validateAndCastList(validateList(obj), BaseCustomResponseParser::toFloat); | ||
| } | ||
|
|
||
| static Float toFloat(Object obj) { | ||
| return toNumber(obj).floatValue(); | ||
| } | ||
|
|
||
| private static Number toNumber(Object obj) { | ||
| if (obj instanceof Number == false) { | ||
| throw new IllegalArgumentException(Strings.format("Unable to convert type [%s] to Number", obj.getClass().getSimpleName())); | ||
| } | ||
|
|
||
| return ((Number) obj); | ||
| } | ||
|
|
||
| static List<Integer> convertToListOfIntegers(Object obj) { | ||
| return validateAndCastList(validateList(obj), BaseCustomResponseParser::toInteger); | ||
| } | ||
|
|
||
| private static Integer toInteger(Object obj) { | ||
| return toNumber(obj).intValue(); | ||
| } | ||
|
|
||
| static <T> List<T> validateAndCastList(List<?> items, Function<Object, T> converter) { | ||
| validateNonNull(items); | ||
|
|
||
| List<T> resultList = new ArrayList<>(); | ||
| for (var obj : items) { | ||
| resultList.add(converter.apply(obj)); | ||
| } | ||
|
|
||
| return resultList; | ||
| } | ||
|
|
||
| static <T> T toType(Object obj, Class<T> type) { | ||
| validateNonNull(obj); | ||
|
|
||
| if (type.isInstance(obj) == false) { | ||
| throw new IllegalArgumentException( | ||
| Strings.format("Unable to convert object of type [%s] to type [%s]", obj.getClass().getSimpleName(), type.getSimpleName()) | ||
| ); | ||
| } | ||
|
|
||
| return type.cast(obj); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.custom.response; | ||
|
|
||
| import org.elasticsearch.common.Strings; | ||
| import org.elasticsearch.common.ValidationException; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.xcontent.XContentBuilder; | ||
| import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; | ||
| import org.elasticsearch.xpack.inference.common.MapPathExtractor; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
|
|
||
| import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; | ||
| import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; | ||
|
|
||
| public class CompletionResponseParser extends BaseCustomResponseParser<ChatCompletionResults> { | ||
|
|
||
| public static final String NAME = "completion_response_parser"; | ||
| public static final String COMPLETION_PARSER_RESULT = "completion_result"; | ||
|
|
||
| private final String completionResultPath; | ||
|
|
||
| public static CompletionResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) { | ||
| var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException); | ||
|
|
||
| if (path == null) { | ||
| throw validationException; | ||
| } | ||
|
|
||
| return new CompletionResponseParser(path); | ||
| } | ||
|
|
||
| public CompletionResponseParser(String completionResultPath) { | ||
| this.completionResultPath = Objects.requireNonNull(completionResultPath); | ||
| } | ||
|
|
||
| public CompletionResponseParser(StreamInput in) throws IOException { | ||
| this.completionResultPath = in.readString(); | ||
| } | ||
|
|
||
| public void writeTo(StreamOutput out) throws IOException { | ||
| out.writeString(completionResultPath); | ||
| } | ||
|
|
||
| public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
| builder.startObject(JSON_PARSER); | ||
| { | ||
| builder.field(COMPLETION_PARSER_RESULT, completionResultPath); | ||
| } | ||
| builder.endObject(); | ||
| return builder; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (this == o) return true; | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| CompletionResponseParser that = (CompletionResponseParser) o; | ||
| return Objects.equals(completionResultPath, that.completionResultPath); | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(completionResultPath); | ||
| } | ||
|
|
||
| @Override | ||
| public String getWriteableName() { | ||
| return NAME; | ||
| } | ||
|
|
||
| @Override | ||
| public ChatCompletionResults transform(Map<String, Object> map) { | ||
| var extractedField = MapPathExtractor.extract(map, completionResultPath); | ||
|
|
||
| validateNonNull(extractedField); | ||
|
|
||
| if (extractedField instanceof List<?> extractedList) { | ||
| var completionList = validateAndCastList(extractedList, (obj) -> toType(obj, String.class)); | ||
| return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList()); | ||
| } else if (extractedField instanceof String extractedString) { | ||
| return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(extractedString))); | ||
| } else { | ||
| throw new IllegalArgumentException( | ||
| Strings.format( | ||
| "Extracted field is an invalid type, expected a list or a string but received [%s]", | ||
| extractedField.getClass().getSimpleName() | ||
| ) | ||
| ); | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.custom.response; | ||
|
|
||
| import org.elasticsearch.common.io.stream.NamedWriteable; | ||
| import org.elasticsearch.inference.InferenceServiceResults; | ||
| import org.elasticsearch.xcontent.ToXContentFragment; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| public interface CustomResponseParser extends ToXContentFragment, NamedWriteable { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The response parsers need to be a named writeable because the service settings will use a factory to construct the appropriate response parser. The response parser is chosen based on the task type of the service which is provided in the PUT request.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll register these in a future PR
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we plan to handle streaming, since
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah good point, I hadn't really thought about that. I think you're suggestion will work well though. |
||
| InferenceServiceResults parse(HttpResult response) throws IOException; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.custom.response; | ||
|
|
||
| import org.elasticsearch.common.ValidationException; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.xcontent.ToXContentFragment; | ||
| import org.elasticsearch.xcontent.XContentBuilder; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xcontent.XContentParser; | ||
| import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.inference.common.MapPathExtractor; | ||
| import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
| import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
| import java.util.function.Function; | ||
|
|
||
| import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; | ||
| import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER; | ||
| import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType; | ||
|
|
||
| public class ErrorResponseParser implements ToXContentFragment, Function<HttpResult, ErrorResponse> { | ||
|
|
||
| public static final String MESSAGE_PATH = "path"; | ||
|
|
||
| private final String messagePath; | ||
|
|
||
| public static ErrorResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) { | ||
| var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException); | ||
|
|
||
| if (path == null) { | ||
| throw validationException; | ||
| } | ||
|
|
||
| return new ErrorResponseParser(path); | ||
| } | ||
|
|
||
| public ErrorResponseParser(String messagePath) { | ||
| this.messagePath = Objects.requireNonNull(messagePath); | ||
| } | ||
|
|
||
| public ErrorResponseParser(StreamInput in) throws IOException { | ||
| this.messagePath = in.readString(); | ||
| } | ||
|
|
||
| public void writeTo(StreamOutput out) throws IOException { | ||
| out.writeString(messagePath); | ||
| } | ||
|
|
||
| public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
| builder.startObject(ERROR_PARSER); | ||
| { | ||
| builder.field(MESSAGE_PATH, messagePath); | ||
| } | ||
| builder.endObject(); | ||
| return builder; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (this == o) return true; | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| ErrorResponseParser that = (ErrorResponseParser) o; | ||
| return Objects.equals(messagePath, that.messagePath); | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(messagePath); | ||
| } | ||
|
|
||
| @Override | ||
| public ErrorResponse apply(HttpResult httpResult) { | ||
| try ( | ||
| XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) | ||
| .createParser(XContentParserConfiguration.EMPTY, httpResult.body()) | ||
| ) { | ||
| var map = jsonParser.map(); | ||
|
|
||
| // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic | ||
| // if we find the top level error field we'll return a response with an empty message but indicate | ||
| // that we found the structure of the error object. Here if we're missing the final field we will return | ||
| // a ErrorResponse.UNDEFINED_ERROR with will indicate that we did not find the structure even if for example | ||
jonathan-buttner marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // the outer error field does exist, but it doesn't contain the nested field we were looking for. | ||
| // If in the future we want the previous behavior, we can add a new message_path field or something and have | ||
| // the current path field point to the field that indicates whether we found an error object. | ||
| var errorText = toType(MapPathExtractor.extract(map, messagePath), String.class); | ||
| return new ErrorResponse(errorText); | ||
| } catch (Exception e) { | ||
| // swallow the error | ||
| } | ||
|
|
||
| return ErrorResponse.UNDEFINED_ERROR; | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needed these so the
assertThatcalls in the error parser tests can succeed.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively could
ErrorResponsebe a record rather than an classThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that would make it cleaner. At the moment
ErrorResponseis a base class and extended in a few places so we'd have to switch that to composition (I don't think we can extend a record 🤔 ).