diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java index ea829b2eb7bac..98cc6dfb8a1ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java @@ -10,8 +10,10 @@ import org.elasticsearch.common.Strings; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.regex.Pattern; /** @@ -78,6 +80,8 @@ * [1, 2] * ] * } + * + * The array field names would be {@code ["embeddings", "embedding"} * * * This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array. @@ -123,10 +127,28 @@ public class MapPathExtractor { private static final String DOLLAR = "$"; // default for testing - static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)"); - static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)"); + static final Pattern DOT_FIELD_PATTERN = Pattern.compile("^\\.([^.\\[]+)(.*)"); + static final Pattern ARRAY_WILDCARD_PATTERN = Pattern.compile("^\\[\\*\\](.*)"); + public static final String UNKNOWN_FIELD_NAME = "unknown"; + + /** + * A result object that tries to match up the field names parsed from the passed in path and the result + * extracted from the passed in map. 
+ * @param extractedObject represents the extracted result from the map + * @param traversedFields a list of field names in order as they're encountered while navigating through the nested objects + */ + public record Result(Object extractedObject, List traversedFields) { + public String getArrayFieldName(int index) { + // if the index is out of bounds we'll return a default value + if (traversedFields.size() <= index || index < 0) { + return UNKNOWN_FIELD_NAME; + } + + return traversedFields.get(index); + } + } - public static Object extract(Map data, String path) { + public static Result extract(Map data, String path) { if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) { return null; } @@ -139,16 +161,41 @@ public static Object extract(Map data, String path) { throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath)); } - return navigate(data, cleanedPath); + var fieldNames = new LinkedHashSet(); + + return new Result(navigate(data, cleanedPath, new FieldNameInfo("", "", fieldNames)), fieldNames.stream().toList()); } - private static Object navigate(Object current, String remainingPath) { - if (current == null || remainingPath == null || remainingPath.isEmpty()) { + private record FieldNameInfo(String currentPath, String fieldName, Set traversedFields) { + void addTraversedField(String fieldName) { + traversedFields.add(createPath(fieldName)); + } + + void addCurrentField() { + traversedFields.add(currentPath); + } + + FieldNameInfo descend(String newFieldName) { + var newLocation = createPath(newFieldName); + return new FieldNameInfo(newLocation, newFieldName, traversedFields); + } + + private String createPath(String newFieldName) { + if (Strings.isNullOrEmpty(currentPath)) { + return newFieldName; + } else { + return currentPath + "." 
+ newFieldName; + } + } + } + + private static Object navigate(Object current, String remainingPath, FieldNameInfo fieldNameInfo) { + if (current == null || Strings.isNullOrEmpty(remainingPath)) { return current; } - var dotFieldMatcher = dotFieldPattern.matcher(remainingPath); - var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath); + var dotFieldMatcher = DOT_FIELD_PATTERN.matcher(remainingPath); + var arrayWildcardMatcher = ARRAY_WILDCARD_PATTERN.matcher(remainingPath); if (dotFieldMatcher.matches()) { String field = dotFieldMatcher.group(1); @@ -168,7 +215,12 @@ private static Object navigate(Object current, String remainingPath) { throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field)); } - return navigate(currentMap.get(field), nextPath); + // Handle the case where the path was $.result.text or $.result[*].key + if (Strings.isNullOrEmpty(nextPath)) { + fieldNameInfo.addTraversedField(field); + } + + return navigate(currentMap.get(field), nextPath, fieldNameInfo.descend(field)); } else { throw new IllegalArgumentException( Strings.format( @@ -182,10 +234,12 @@ private static Object navigate(Object current, String remainingPath) { } else if (arrayWildcardMatcher.matches()) { String nextPath = arrayWildcardMatcher.group(1); if (current instanceof List list) { + fieldNameInfo.addCurrentField(); + List results = new ArrayList<>(); for (Object item : list) { - Object result = navigate(item, nextPath); + Object result = navigate(item, nextPath, fieldNameInfo); if (result != null) { results.add(result); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java index be9669c331371..7fc272931e7fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java @@ -34,4 +34,16 @@ public String getErrorMessage() { public boolean errorStructureFound() { return errorStructureFound; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ErrorResponse that = (ErrorResponse) o; + return errorStructureFound == that.errorStructureFound && Objects.equals(errorMessage, that.errorMessage); + } + + @Override + public int hashCode() { + return Objects.hash(errorMessage, errorStructureFound); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java new file mode 100644 index 0000000000000..d0f9faf283aef --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
/**
 * Field-name constants used when parsing and serializing the custom service's
 * settings (URL, headers, request template, and the response/error parser configs).
 *
 * This is a constants holder only; it is not meant to be instantiated.
 */
public class CustomServiceSettings {
    public static final String NAME = "custom_service_settings";
    public static final String URL = "url";
    public static final String HEADERS = "headers";
    public static final String REQUEST = "request";
    public static final String REQUEST_CONTENT = "content";
    public static final String RESPONSE = "response";
    public static final String JSON_PARSER = "json_parser";
    public static final String ERROR_PARSER = "error_parser";

    // Prevent instantiation: this class only exposes constants.
    private CustomServiceSettings() {}
}
+ */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; + +public abstract class BaseCustomResponseParser implements CustomResponseParser { + + @Override + public InferenceServiceResults parse(HttpResult response) throws IOException { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var map = jsonParser.map(); + + return transform(map); + } + } + + protected abstract T transform(Map extractedField); + + static List validateList(Object obj, String fieldName) { + validateNonNull(obj, fieldName); + + if (obj instanceof List == false) { + throw new IllegalArgumentException( + Strings.format( + "Extracted field [%s] is an invalid type, expected a list but received [%s]", + fieldName, + obj.getClass().getSimpleName() + ) + ); + } + + return (List) obj; + } + + static void validateNonNull(Object obj, String fieldName) { + Objects.requireNonNull(obj, Strings.format("Failed to parse field [%s], extracted field was null", fieldName)); + } + + static Map validateMap(Object obj, String fieldName) { + validateNonNull(obj, fieldName); + + if (obj instanceof Map == false) { + throw new IllegalArgumentException( + Strings.format( + "Extracted field [%s] is an invalid type, expected a map but received [%s]", + fieldName, + obj.getClass().getSimpleName() + ) + ); + } + + var keys = ((Map) obj).keySet(); + 
for (var key : keys) { + if (key instanceof String == false) { + throw new IllegalStateException( + Strings.format( + "Extracted field [%s] map has an invalid key type. Expected a string but received [%s]", + fieldName, + key.getClass().getSimpleName() + ) + ); + } + } + + @SuppressWarnings("unchecked") + var result = (Map) obj; + return result; + } + + static List convertToListOfFloats(Object obj, String fieldName) { + return castList(validateList(obj, fieldName), BaseCustomResponseParser::toFloat, fieldName); + } + + static Float toFloat(Object obj, String fieldName) { + return toNumber(obj, fieldName).floatValue(); + } + + private static Number toNumber(Object obj, String fieldName) { + if (obj instanceof Number == false) { + throw new IllegalArgumentException( + Strings.format("Unable to convert field [%s] of type [%s] to Number", fieldName, obj.getClass().getSimpleName()) + ); + } + + return ((Number) obj); + } + + static List convertToListOfIntegers(Object obj, String fieldName) { + return castList(validateList(obj, fieldName), BaseCustomResponseParser::toInteger, fieldName); + } + + private static Integer toInteger(Object obj, String fieldName) { + return toNumber(obj, fieldName).intValue(); + } + + static List castList(List items, BiFunction converter, String fieldName) { + validateNonNull(items, fieldName); + + List resultList = new ArrayList<>(); + for (int i = 0; i < items.size(); i++) { + try { + resultList.add(converter.apply(items.get(i), fieldName)); + } catch (Exception e) { + throw new IllegalStateException(Strings.format("Failed to parse list entry [%d], error: %s", i, e.getMessage()), e); + } + } + + return resultList; + } + + static T toType(Object obj, Class type, String fieldName) { + validateNonNull(obj, fieldName); + + if (type.isInstance(obj) == false) { + throw new IllegalArgumentException( + Strings.format( + "Unable to convert field [%s] of type [%s] to [%s]", + fieldName, + obj.getClass().getSimpleName(), + type.getSimpleName() + ) + ); 
+ } + + return type.cast(obj); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java new file mode 100644 index 0000000000000..762556fb381ed --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParser.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.common.MapPathExtractor; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; + +public class CompletionResponseParser extends BaseCustomResponseParser { + + public static final String NAME = "completion_response_parser"; + public static final String COMPLETION_PARSER_RESULT = "completion_result"; + + private final String completionResultPath; + + public static CompletionResponseParser fromMap(Map responseParserMap, ValidationException validationException) { + var path = 
extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CompletionResponseParser(path); + } + + public CompletionResponseParser(String completionResultPath) { + this.completionResultPath = Objects.requireNonNull(completionResultPath); + } + + public CompletionResponseParser(StreamInput in) throws IOException { + this.completionResultPath = in.readString(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(completionResultPath); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(JSON_PARSER); + { + builder.field(COMPLETION_PARSER_RESULT, completionResultPath); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CompletionResponseParser that = (CompletionResponseParser) o; + return Objects.equals(completionResultPath, that.completionResultPath); + } + + @Override + public int hashCode() { + return Objects.hash(completionResultPath); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public ChatCompletionResults transform(Map map) { + var result = MapPathExtractor.extract(map, completionResultPath); + var extractedField = result.extractedObject(); + + validateNonNull(extractedField, completionResultPath); + + if (extractedField instanceof List extractedList) { + var completionList = castList(extractedList, (obj, fieldName) -> toType(obj, String.class, fieldName), completionResultPath); + return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList()); + } else if (extractedField instanceof String extractedString) { + return new ChatCompletionResults(List.of(new 
ChatCompletionResults.Result(extractedString))); + } else { + throw new IllegalArgumentException( + Strings.format( + "Extracted field [%s] from path [%s] is an invalid type, expected a list or a string but received [%s]", + result.getArrayFieldName(0), + completionResultPath, + extractedField.getClass().getSimpleName() + ) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java new file mode 100644 index 0000000000000..3a421307d76a8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseParser.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; + +public interface CustomResponseParser extends ToXContentFragment, NamedWriteable { + InferenceServiceResults parse(HttpResult response) throws IOException; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java new file mode 100644 index 0000000000000..d05fa68595b3a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.MapPathExtractor; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER; +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType; + +public class ErrorResponseParser implements ToXContentFragment, Function { + + public static final String MESSAGE_PATH = "path"; + + private final String messagePath; + + public static ErrorResponseParser fromMap(Map responseParserMap, ValidationException validationException) { + var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ErrorResponseParser(path); + } + + public ErrorResponseParser(String messagePath) { + this.messagePath = Objects.requireNonNull(messagePath); + } + + public ErrorResponseParser(StreamInput in) throws IOException { + this.messagePath = in.readString(); + } + + public void writeTo(StreamOutput 
out) throws IOException { + out.writeString(messagePath); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(ERROR_PARSER); + { + builder.field(MESSAGE_PATH, messagePath); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ErrorResponseParser that = (ErrorResponseParser) o; + return Objects.equals(messagePath, that.messagePath); + } + + @Override + public int hashCode() { + return Objects.hash(messagePath); + } + + @Override + public ErrorResponse apply(HttpResult httpResult) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, httpResult.body()) + ) { + var map = jsonParser.map(); + + // NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic + // if we find the top level error field we'll return a response with an empty message but indicate + // that we found the structure of the error object. Here if we're missing the final field we will return + // a ErrorResponse.UNDEFINED_ERROR which will indicate that we did not find the structure even if for example + // the outer error field does exist, but it doesn't contain the nested field we were looking for. + // If in the future we want the previous behavior, we can add a new message_path field or something and have + // the current path field point to the field that indicates whether we found an error object. 
+ var errorText = toType(MapPathExtractor.extract(map, messagePath).extractedObject(), String.class, messagePath); + return new ErrorResponse(errorText); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/NoopResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/NoopResponseParser.java new file mode 100644 index 0000000000000..c01086cd83ee0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/NoopResponseParser.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; + +public record NoopResponseParser() implements CustomResponseParser { + + public static final String NAME = "noop_response_parser"; + public static final NoopResponseParser INSTANCE = new NoopResponseParser(); + + public static NoopResponseParser fromMap() { + return new NoopResponseParser(); + } + + public NoopResponseParser(StreamInput in) { + this(); + } + + public void writeTo(StreamOutput out) throws IOException {} + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + 
} + + @Override + public InferenceServiceResults parse(HttpResult result) { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java new file mode 100644 index 0000000000000..18d3cbbad051b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParser.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.common.MapPathExtractor; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; + +public class RerankResponseParser extends BaseCustomResponseParser { + + public static final String NAME = "rerank_response_parser"; + public static final String RERANK_PARSER_SCORE = "relevance_score"; + public static 
final String RERANK_PARSER_INDEX = "reranked_index"; + public static final String RERANK_PARSER_DOCUMENT_TEXT = "document_text"; + + private final String relevanceScorePath; + private final String rerankIndexPath; + private final String documentTextPath; + + public static RerankResponseParser fromMap(Map responseParserMap, ValidationException validationException) { + + var relevanceScore = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException); + var rerankIndex = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException); + var documentText = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new RerankResponseParser(relevanceScore, rerankIndex, documentText); + } + + public RerankResponseParser(String relevanceScorePath) { + this(relevanceScorePath, null, null); + } + + public RerankResponseParser(String relevanceScorePath, @Nullable String rerankIndexPath, @Nullable String documentTextPath) { + this.relevanceScorePath = Objects.requireNonNull(relevanceScorePath); + this.rerankIndexPath = rerankIndexPath; + this.documentTextPath = documentTextPath; + } + + public RerankResponseParser(StreamInput in) throws IOException { + this.relevanceScorePath = in.readString(); + this.rerankIndexPath = in.readOptionalString(); + this.documentTextPath = in.readOptionalString(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(relevanceScorePath); + out.writeOptionalString(rerankIndexPath); + out.writeOptionalString(documentTextPath); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(JSON_PARSER); + { + builder.field(RERANK_PARSER_SCORE, relevanceScorePath); + if (rerankIndexPath != null) { + builder.field(RERANK_PARSER_INDEX, 
rerankIndexPath); + } + + if (documentTextPath != null) { + builder.field(RERANK_PARSER_DOCUMENT_TEXT, documentTextPath); + } + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RerankResponseParser that = (RerankResponseParser) o; + return Objects.equals(relevanceScorePath, that.relevanceScorePath) + && Objects.equals(rerankIndexPath, that.rerankIndexPath) + && Objects.equals(documentTextPath, that.documentTextPath); + } + + @Override + public int hashCode() { + return Objects.hash(relevanceScorePath, rerankIndexPath, documentTextPath); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public RankedDocsResults transform(Map map) { + var scores = extractScores(map); + var indices = extractIndices(map); + var documents = extractDocuments(map); + + if (indices != null && indices.size() != scores.size()) { + throw new IllegalStateException( + Strings.format( + "The number of index fields [%d] was not the same as the number of scores [%d]", + indices.size(), + scores.size() + ) + ); + } + + if (documents != null && documents.size() != scores.size()) { + throw new IllegalStateException( + Strings.format( + "The number of document fields [%d] was not the same as the number of scores [%d]", + documents.size(), + scores.size() + ) + ); + } + + var rankedDocs = new ArrayList(); + for (int i = 0; i < scores.size(); i++) { + var index = indices != null ? indices.get(i) : i; + var score = scores.get(i); + var document = documents != null ? 
documents.get(i) : null; + rankedDocs.add(new RankedDocsResults.RankedDoc(index, score, document)); + } + + return new RankedDocsResults(rankedDocs); + } + + private List extractScores(Map map) { + try { + var result = MapPathExtractor.extract(map, relevanceScorePath); + return convertToListOfFloats(result.extractedObject(), result.getArrayFieldName(0)); + } catch (Exception e) { + throw new IllegalStateException(Strings.format("Failed to parse rerank scores, error: %s", e.getMessage()), e); + } + } + + private List extractIndices(Map map) { + if (rerankIndexPath != null) { + try { + var indexResult = MapPathExtractor.extract(map, rerankIndexPath); + return convertToListOfIntegers(indexResult.extractedObject(), indexResult.getArrayFieldName(0)); + } catch (Exception e) { + throw new IllegalStateException(Strings.format("Failed to parse rerank indices, error: %s", e.getMessage()), e); + } + } + + return null; + } + + private List extractDocuments(Map map) { + try { + if (documentTextPath != null) { + var documentResult = MapPathExtractor.extract(map, documentTextPath); + var documentFieldName = documentResult.getArrayFieldName(0); + return castList( + validateList(documentResult.extractedObject(), documentFieldName), + (obj, fieldName) -> toType(obj, String.class, fieldName), + documentFieldName + ); + } + } catch (Exception e) { + throw new IllegalStateException(Strings.format("Failed to parse rerank documents, error: %s", e.getMessage()), e); + } + + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java new file mode 100644 index 0000000000000..b6c83fd7fbfc6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParser.java @@ -0,0 +1,165 
@@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.common.MapPathExtractor; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; + +public class SparseEmbeddingResponseParser extends BaseCustomResponseParser { + + public static final String NAME = "sparse_embedding_response_parser"; + public static final String SPARSE_EMBEDDING_TOKEN_PATH = "token_path"; + public static final String SPARSE_EMBEDDING_WEIGHT_PATH = "weight_path"; + + private final String tokenPath; + private final String weightPath; + + public static SparseEmbeddingResponseParser fromMap(Map responseParserMap, ValidationException validationException) { + var tokenPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, JSON_PARSER, validationException); + + var weightPath = extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, JSON_PARSER, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw 
validationException; + } + + return new SparseEmbeddingResponseParser(tokenPath, weightPath); + } + + public SparseEmbeddingResponseParser(String tokenPath, String weightPath) { + this.tokenPath = Objects.requireNonNull(tokenPath); + this.weightPath = Objects.requireNonNull(weightPath); + } + + public SparseEmbeddingResponseParser(StreamInput in) throws IOException { + this.tokenPath = in.readString(); + this.weightPath = in.readString(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(tokenPath); + out.writeString(weightPath); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(JSON_PARSER); + { + builder.field(SPARSE_EMBEDDING_TOKEN_PATH, tokenPath); + builder.field(SPARSE_EMBEDDING_WEIGHT_PATH, weightPath); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SparseEmbeddingResponseParser that = (SparseEmbeddingResponseParser) o; + return Objects.equals(tokenPath, that.tokenPath) && Objects.equals(weightPath, that.weightPath); + } + + @Override + public int hashCode() { + return Objects.hash(tokenPath, weightPath); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected SparseEmbeddingResults transform(Map map) { + // These will be List> + var tokenResult = MapPathExtractor.extract(map, tokenPath); + var tokens = validateList(tokenResult.extractedObject(), tokenResult.getArrayFieldName(0)); + + // These will be List> + var weightResult = MapPathExtractor.extract(map, weightPath); + var weights = validateList(weightResult.extractedObject(), weightResult.getArrayFieldName(0)); + + validateListsSize(tokens, weights); + + var tokenEntryFieldName = tokenResult.getArrayFieldName(1); + var weightEntryFieldName = weightResult.getArrayFieldName(1); + var embeddings = new ArrayList(); 
+ for (int responseCounter = 0; responseCounter < tokens.size(); responseCounter++) { + try { + var tokenEntryList = validateList(tokens.get(responseCounter), tokenEntryFieldName); + var weightEntryList = validateList(weights.get(responseCounter), weightEntryFieldName); + + validateListsSize(tokenEntryList, weightEntryList); + + embeddings.add(createEmbedding(tokenEntryList, weightEntryList, weightEntryFieldName)); + } catch (Exception e) { + throw new IllegalStateException( + Strings.format("Failed to parse sparse embedding entry [%d], error: %s", responseCounter, e.getMessage()), + e + ); + } + } + + return new SparseEmbeddingResults(Collections.unmodifiableList(embeddings)); + } + + private static void validateListsSize(List tokens, List weights) { + if (tokens.size() != weights.size()) { + throw new IllegalStateException( + Strings.format( + "The extracted tokens list is size [%d] but the weights list is size [%d]. The list sizes must be equal.", + tokens.size(), + weights.size() + ) + ); + } + } + + private static SparseEmbeddingResults.Embedding createEmbedding( + List tokenEntryList, + List weightEntryList, + String weightFieldName + ) { + var weightedTokens = new ArrayList(); + + for (int embeddingCounter = 0; embeddingCounter < tokenEntryList.size(); embeddingCounter++) { + var token = tokenEntryList.get(embeddingCounter); + var weight = weightEntryList.get(embeddingCounter); + + // Alibaba can return a token id which is an integer and needs to be converted to a string + var tokenIdAsString = token.toString(); + try { + var weightAsFloat = toFloat(weight, weightFieldName); + weightedTokens.add(new WeightedToken(tokenIdAsString, weightAsFloat)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + Strings.format("Failed to parse weight item: [%d] of array, error: %s", embeddingCounter, e.getMessage()), + e + ); + } + } + + return new SparseEmbeddingResults.Embedding(weightedTokens, false); + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java new file mode 100644 index 0000000000000..fe5b4ec236282 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParser.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.common.MapPathExtractor; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; + +public class TextEmbeddingResponseParser extends BaseCustomResponseParser { + + public static final String NAME = "text_embedding_response_parser"; + public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings"; + + private final String textEmbeddingsPath; + + public static TextEmbeddingResponseParser fromMap(Map responseParserMap, ValidationException validationException) { + var path = extractRequiredString(responseParserMap, 
TEXT_EMBEDDING_PARSER_EMBEDDINGS, JSON_PARSER, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new TextEmbeddingResponseParser(path); + } + + public TextEmbeddingResponseParser(String textEmbeddingsPath) { + this.textEmbeddingsPath = Objects.requireNonNull(textEmbeddingsPath); + } + + public TextEmbeddingResponseParser(StreamInput in) throws IOException { + this.textEmbeddingsPath = in.readString(); + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(textEmbeddingsPath); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(JSON_PARSER); + { + builder.field(TEXT_EMBEDDING_PARSER_EMBEDDINGS, textEmbeddingsPath); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TextEmbeddingResponseParser that = (TextEmbeddingResponseParser) o; + return Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath); + } + + @Override + public int hashCode() { + return Objects.hash(textEmbeddingsPath); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected TextEmbeddingFloatResults transform(Map map) { + var extractedResult = MapPathExtractor.extract(map, textEmbeddingsPath); + var mapResultsList = validateList(extractedResult.extractedObject(), extractedResult.getArrayFieldName(0)); + + var embeddings = new ArrayList(mapResultsList.size()); + + for (int i = 0; i < mapResultsList.size(); i++) { + try { + var entry = mapResultsList.get(i); + var embeddingsAsListFloats = convertToListOfFloats(entry, extractedResult.getArrayFieldName(1)); + embeddings.add(TextEmbeddingFloatResults.Embedding.of(embeddingsAsListFloats)); + } catch (Exception e) { + throw new IllegalArgumentException( + Strings.format("Failed to parse 
text embedding entry [%d], error: %s", i, e.getMessage()), + e + ); + } + } + + return new TextEmbeddingFloatResults(embeddings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java index cd084ca224798..047c0c8d647fb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java @@ -21,7 +21,15 @@ public void testExtract_RetrievesListOfLists() { Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4)))) ); - assertThat(MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"), is(List.of(List.of(1, 2), List.of(3, 4)))); + assertThat( + MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"), + is( + new MapPathExtractor.Result( + List.of(List.of(1, 2), List.of(3, 4)), + List.of("result.embeddings", "result.embeddings.embedding") + ) + ) + ); } public void testExtract_IteratesListOfMapsToListOfStrings() { @@ -32,7 +40,29 @@ public void testExtract_IteratesListOfMapsToListOfStrings() { assertThat( MapPathExtractor.extract(input, "$.result[*].key[*]"), - is(List.of(List.of("value1", "value2"), List.of("value3", "value4"))) + is( + new MapPathExtractor.Result( + List.of(List.of("value1", "value2"), List.of("value3", "value4")), + List.of("result", "result.key") + ) + ) + ); + } + + public void testExtract_IteratesListOfMapsToListOfStrings_WithoutFinalArraySyntax() { + Map input = Map.of( + "result", + List.of(Map.of("key", List.of("value1", "value2")), Map.of("key", List.of("value3", "value4"))) + ); + + assertThat( + MapPathExtractor.extract(input, "$.result[*].key"), + is( + new MapPathExtractor.Result( + List.of(List.of("value1", 
"value2"), List.of("value3", "value4")), + List.of("result", "result.key") + ) + ) ); } @@ -45,7 +75,15 @@ public void testExtract_IteratesListOfMapsToListOfMapsOfStringToDoubles() { ) ); - assertThat(MapPathExtractor.extract(input, "$.result[*].key[*].a"), is(List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d)))); + assertThat( + MapPathExtractor.extract(input, "$.result[*].key[*].a"), + is( + new MapPathExtractor.Result( + List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d)), + List.of("result", "result.key", "result.key.a") + ) + ) + ); } public void testExtract_ReturnsNullForEmptyList() { @@ -128,36 +166,36 @@ public void testExtract_ThrowsException_WhenHasArraySyntaxButIsAMap() { public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty() { Map input = Map.of("result", List.of()); - assertThat(MapPathExtractor.extract(input, "$.result"), is(List.of())); + assertThat(MapPathExtractor.extract(input, "$.result"), is(new MapPathExtractor.Result(List.of(), List.of("result")))); } public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty_PathIncludesArray() { Map input = Map.of("result", List.of()); - assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(List.of())); + assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(new MapPathExtractor.Result(List.of(), List.of("result")))); } public void testDotFieldPattern() { { - var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc.123"); + var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc.123"); assertTrue(matcher.matches()); assertThat(matcher.group(1), is("abc")); assertThat(matcher.group(2), is(".123")); } { - var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[*].123"); + var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc[*].123"); assertTrue(matcher.matches()); assertThat(matcher.group(1), is("abc")); assertThat(matcher.group(2), is("[*].123")); } { - var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[.123"); + var matcher = 
MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc[.123"); assertTrue(matcher.matches()); assertThat(matcher.group(1), is("abc")); assertThat(matcher.group(2), is("[.123")); } { - var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc"); + var matcher = MapPathExtractor.DOT_FIELD_PATTERN.matcher(".abc"); assertTrue(matcher.matches()); assertThat(matcher.group(1), is("abc")); assertThat(matcher.group(2), is("")); @@ -166,21 +204,21 @@ public void testDotFieldPattern() { public void testArrayWildcardPattern() { { - var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*].abc.123"); + var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[*].abc.123"); assertTrue(matcher.matches()); assertThat(matcher.group(1), is(".abc.123")); } { - var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*]"); + var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[*]"); assertTrue(matcher.matches()); assertThat(matcher.group(1), is("")); } { - var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[1].abc"); + var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[1].abc"); assertFalse(matcher.matches()); } { - var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[].abc"); + var matcher = MapPathExtractor.ARRAY_WILDCARD_PATTERN.matcher("[].abc"); assertFalse(matcher.matches()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParserTests.java new file mode 100644 index 0000000000000..b1cf2ba0934eb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/BaseCustomResponseParserTests.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.castList; +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.convertToListOfFloats; +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toFloat; +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType; +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.validateList; +import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.validateMap; +import static org.hamcrest.Matchers.is; + +public class BaseCustomResponseParserTests extends ESTestCase { + public void testValidateNonNull_ThrowsException_WhenPassedNull() { + var exception = expectThrows(NullPointerException.class, () -> BaseCustomResponseParser.validateNonNull(null, "field")); + assertThat(exception.getMessage(), is("Failed to parse field [field], extracted field was null")); + } + + public void testValidateList_ThrowsException_WhenPassedAnObjectThatIsNotAList() { + var exception = expectThrows(IllegalArgumentException.class, () -> validateList(new Object(), "field")); + assertThat(exception.getMessage(), is("Extracted field [field] is an invalid type, expected a list but received [Object]")); + } + + public void testValidateList_ReturnsList() { + Object obj = List.of("abc", "123"); + assertThat(validateList(obj, "field"), is(List.of("abc", "123"))); + } + + public void 
testConvertToListOfFloats_ThrowsException_WhenAnItemInTheListIsNotANumber() { + var list = List.of(1, "hello"); + + var exception = expectThrows(IllegalStateException.class, () -> convertToListOfFloats(list, "field")); + assertThat( + exception.getMessage(), + is("Failed to parse list entry [1], error: Unable to convert field [field] of type [String] to Number") + ); + } + + public void testConvertToListOfFloats_ReturnsList() { + var list = List.of(1, 1.1f, -2.0d, new AtomicInteger(1)); + + assertThat(convertToListOfFloats(list, "field"), is(List.of(1f, 1.1f, -2f, 1f))); + } + + public void testCastList() { + var list = List.of("abc", "123", 1, 2.2d); + + assertThat(castList(list, (obj, fieldName) -> obj.toString(), "field"), is(List.of("abc", "123", "1", "2.2"))); + } + + public void testCastList_ThrowsException() { + var list = List.of("abc"); + + var exception = expectThrows(IllegalStateException.class, () -> castList(list, (obj, fieldName) -> { + throw new IllegalArgumentException("failed"); + }, "field")); + + assertThat(exception.getMessage(), is("Failed to parse list entry [0], error: failed")); + } + + public void testValidateMap() { + assertThat(validateMap(Map.of("abc", 123), "field"), is(Map.of("abc", 123))); + } + + public void testValidateMap_ThrowsException_WhenObjectIsNotAMap() { + var exception = expectThrows(IllegalArgumentException.class, () -> validateMap("hello", "field")); + assertThat(exception.getMessage(), is("Extracted field [field] is an invalid type, expected a map but received [String]")); + } + + public void testValidateMap_ThrowsException_WhenKeysAreNotStrings() { + var exception = expectThrows(IllegalStateException.class, () -> validateMap(Map.of("key", "value", 1, "abc"), "field")); + assertThat( + exception.getMessage(), + is("Extracted field [field] map has an invalid key type. 
Expected a string but received [Integer]") + ); + } + + public void testToFloat() { + assertThat(toFloat(1, "field"), is(1f)); + } + + public void testToFloat_AtomicLong() { + assertThat(toFloat(new AtomicLong(100), "field"), is(100f)); + } + + public void testToFloat_Double() { + assertThat(toFloat(1.123d, "field"), is(1.123f)); + } + + public void testToType() { + Object obj = "hello"; + assertThat(toType(obj, String.class, "field"), is("hello")); + } + + public void testToType_List() { + Object obj = List.of(123, 456); + assertThat(toType(obj, List.class, "field"), is(List.of(123, 456))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java new file mode 100644 index 0000000000000..46cb23a4ceaa5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CompletionResponseParserTests.java @@ -0,0 +1,304 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser.COMPLETION_PARSER_RESULT; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class CompletionResponseParserTests extends AbstractBWCWireSerializationTestCase { + + public static CompletionResponseParser createRandom() { + return new CompletionResponseParser("$." 
+ randomAlphaOfLength(5)); + } + + public void testFromMap() { + var validation = new ValidationException(); + var parser = CompletionResponseParser.fromMap(new HashMap<>(Map.of(COMPLETION_PARSER_RESULT, "$.result[*].text")), validation); + + assertThat(parser, is(new CompletionResponseParser("$.result[*].text"))); + } + + public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> CompletionResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].text")), validation) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [json_parser] does not contain the required setting [completion_result];") + ); + } + + public void testToXContent() throws IOException { + var entity = new CompletionResponseParser("$.result[*].text"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "json_parser": { + "completion_result": "$.result[*].text" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testParse() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": [ + { + "text":"completion results" + } + ], + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var parser = new CompletionResponseParser("$.result[*].text"); + ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, is(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("completion results"))))); + } + + 
public void testParse_String() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": { + "text":"completion results" + }, + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var parser = new CompletionResponseParser("$.result.text"); + ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults, is(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("completion results"))))); + } + + public void testParse_MultipleResults() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": [ + { + "text":"completion results" + }, + { + "text":"completion results2" + } + ], + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var parser = new CompletionResponseParser("$.result[*].text"); + ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is( + new ChatCompletionResults( + List.of(new ChatCompletionResults.Result("completion results"), new ChatCompletionResults.Result("completion results2")) + ) + ) + ); + } + + public void testParse_AnthropicFormat() throws IOException { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "result" + }, + { + "type": "text", + "text": "result2" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + var parser = new 
CompletionResponseParser("$.content[*].text"); + ChatCompletionResults parsedResults = (ChatCompletionResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is(new ChatCompletionResults(List.of(new ChatCompletionResults.Result("result"), new ChatCompletionResults.Result("result2")))) + ); + } + + public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": "invalid_field", + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var parser = new CompletionResponseParser("$.result[*].text"); + var exception = expectThrows( + IllegalArgumentException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Current path [[*].text] matched the array field pattern " + + "but the current object is not a list, found invalid type [String] instead." 
+ ) + ); + } + + public void testParse_ThrowsException_WhenExtractedField_IsNotListOfStrings() { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": ["string", true], + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var parser = new CompletionResponseParser("$.result"); + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is("Failed to parse list entry [1], error: Unable to convert field [$.result] of type [Boolean] to [String]") + ); + } + + public void testParse_ThrowsException_WhenExtractedField_IsNotAListOrString() { + String responseJson = """ + { + "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f", + "latency": 564.903929, + "result": 123, + "usage": { + "output_tokens": 6320, + "input_tokens": 35, + "total_tokens": 6355 + } + } + """; + + var parser = new CompletionResponseParser("$.result"); + var exception = expectThrows( + IllegalArgumentException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is("Extracted field [result] from path [$.result] is an invalid type, expected a list or a string but received [Integer]") + ); + } + + @Override + protected CompletionResponseParser mutateInstanceForVersion(CompletionResponseParser instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return CompletionResponseParser::new; + } + + @Override + protected CompletionResponseParser createTestInstance() { + return createRandom(); + } + + @Override + protected CompletionResponseParser mutateInstance(CompletionResponseParser instance) throws IOException { + return randomValueOtherThan(instance, 
CompletionResponseParserTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java new file mode 100644 index 0000000000000..56987407e02ac --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParserTests.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser.MESSAGE_PATH; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.Mockito.mock; + +public class ErrorResponseParserTests extends ESTestCase { + + public static ErrorResponseParser createRandom() { + return new ErrorResponseParser("$." 
+ randomAlphaOfLength(5)); + } + + public void testFromMap() { + var validation = new ValidationException(); + var parser = ErrorResponseParser.fromMap(new HashMap<>(Map.of(MESSAGE_PATH, "$.error.message")), validation); + + assertThat(parser, is(new ErrorResponseParser("$.error.message"))); + } + + public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> ErrorResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.error.message")), validation) + ); + + assertThat(exception.getMessage(), is("Validation Failed: 1: [error_parser] does not contain the required setting [path];")); + } + + public void testToXContent() throws IOException { + var entity = new ErrorResponseParser("$.error.message"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "error_parser": { + "path": "$.error.message" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testErrorResponse_ExtractsError() throws IOException { + var result = getMockResult(""" + { + "error": { + "message": "test_error_message" + } + }"""); + + var parser = new ErrorResponseParser("$.error.message"); + var error = parser.apply(result); + assertThat(error, is(new ErrorResponse("test_error_message"))); + } + + public void testFromResponse_WithOtherFieldsPresent() throws IOException { + String responseJson = """ + { + "error": { + "message": "You didn't provide an API key", + "type": "invalid_request_error", + "param": null, + "code": null + } + } + """; + + var parser = new ErrorResponseParser("$.error.message"); + var error = parser.apply(getMockResult(responseJson)); + + assertThat(error, is(new ErrorResponse("You didn't 
provide an API key"))); + } + + public void testFromResponse_noMessage() throws IOException { + String responseJson = """ + { + "error": { + "type": "not_found_error" + } + } + """; + + var parser = new ErrorResponseParser("$.error.message"); + var error = parser.apply(getMockResult(responseJson)); + + assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + assertThat(error.getErrorMessage(), is("")); + assertFalse(error.errorStructureFound()); + } + + public void testErrorResponse_ReturnsUndefinedObjectIfNoError() throws IOException { + var mockResult = getMockResult(""" + {"noerror":true}"""); + + var parser = new ErrorResponseParser("$.error.message"); + var error = parser.apply(mockResult); + + assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + } + + public void testErrorResponse_ReturnsUndefinedObjectIfNotJson() { + var result = new HttpResult(mock(HttpResponse.class), Strings.toUTF8Bytes("not a json string")); + + var parser = new ErrorResponseParser("$.error.message"); + var error = parser.apply(result); + assertThat(error, sameInstance(ErrorResponse.UNDEFINED_ERROR)); + } + + private static HttpResult getMockResult(String jsonString) throws IOException { + var response = mock(HttpResponse.class); + return new HttpResult(response, Strings.toUTF8Bytes(XContentHelper.stripWhitespace(jsonString))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java new file mode 100644 index 0000000000000..523d15ec2a805 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/RerankResponseParserTests.java @@ -0,0 +1,456 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_DOCUMENT_TEXT; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_INDEX; +import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class RerankResponseParserTests extends AbstractBWCWireSerializationTestCase { + + public static RerankResponseParser createRandom() { + var indexPath = randomBoolean() ? "$." + randomAlphaOfLength(5) : null; + var documentTextPath = randomBoolean() ? "$." + randomAlphaOfLength(5) : null; + return new RerankResponseParser("$." 
+ randomAlphaOfLength(5), indexPath, documentTextPath); + } + + public void testFromMap() { + var validation = new ValidationException(); + var parser = RerankResponseParser.fromMap( + new HashMap<>( + Map.of( + RERANK_PARSER_SCORE, + "$.result.scores[*].score", + RERANK_PARSER_INDEX, + "$.result.scores[*].index", + RERANK_PARSER_DOCUMENT_TEXT, + "$.result.scores[*].document_text" + ) + ), + validation + ); + + assertThat( + parser, + is(new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", "$.result.scores[*].document_text")) + ); + } + + public void testFromMap_WithoutOptionalFields() { + var validation = new ValidationException(); + var parser = RerankResponseParser.fromMap(new HashMap<>(Map.of(RERANK_PARSER_SCORE, "$.result.scores[*].score")), validation); + + assertThat(parser, is(new RerankResponseParser("$.result.scores[*].score", null, null))); + } + + public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> RerankResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [json_parser] does not contain the required setting [relevance_score];") + ); + } + + public void testToXContent() throws IOException { + var entity = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", "$.result.scores[*].document_text"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "json_parser": { + "relevance_score": "$.result.scores[*].score", + "reranked_index": "$.result.scores[*].index", + "document_text": "$.result.scores[*].document_text" + } + } 
+ """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_WithoutOptionalFields() throws IOException { + var entity = new RerankResponseParser("$.result.scores[*].score"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "json_parser": { + "relevance_score": "$.result.scores[*].score" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testParse() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "index":1, + "score": 1.37 + }, + { + "index":0, + "score": -0.3 + } + ] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null); + RankedDocsResults parsedResults = (RankedDocsResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is( + new RankedDocsResults( + List.of(new RankedDocsResults.RankedDoc(1, 1.37f, null), new RankedDocsResults.RankedDoc(0, -0.3f, null)) + ) + ) + ); + } + + public void testParse_ThrowsException_WhenIndex_IsInvalid() { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "index":"abc", + "score": 1.37 + }, + { + "index":0, + "score": -0.3 + } + ] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), 
responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse rerank indices, error: Failed to parse list entry [0], " + + "error: Unable to convert field [result.scores] of type [String] to Number" + ) + ); + } + + public void testParse_ThrowsException_WhenScore_IsInvalid() { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "index":1, + "score": true + }, + { + "index":0, + "score": -0.3 + } + ] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse rerank scores, error: Failed to parse list entry [0], " + + "error: Unable to convert field [result.scores] of type [Boolean] to Number" + ) + ); + } + + public void testParse_ThrowsException_WhenDocument_IsInvalid() { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "index":1, + "score": 0.2, + "document": 1 + }, + { + "index":0, + "score": -0.3, + "document": "a document" + } + ] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", "$.result.scores[*].document"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse rerank documents, error: Failed to parse list entry [0], error: " + + "Unable to convert field [result.scores] of type [Integer] to 
[String]" + ) + ); + } + + public void testParse_ThrowsException_WhenIndices_ListSizeDoesNotMatchScores() { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "indices": [1], + "scores": [0.2, 0.3], + "documents": ["a", "b"] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores", "$.result.indices", "$.result.documents"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat(exception.getMessage(), is("The number of index fields [1] was not the same as the number of scores [2]")); + } + + public void testParse_ThrowsException_WhenDocuments_ListSizeDoesNotMatchScores() { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "indices": [1, 0], + "scores": [0.2, 0.3], + "documents": ["a"] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores", "$.result.indices", "$.result.documents"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat(exception.getMessage(), is("The number of document fields [1] was not the same as the number of scores [2]")); + } + + public void testParse_WithoutIndex() throws IOException { + String responseJson = """ + { + "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f", + "latency": 564.903929, + "usage": { + "doc_count": 2 + }, + "result": { + "scores":[ + { + "score": 1.37 + }, + { + "score": -0.3 + } + ] + } + } + """; + + var parser = new RerankResponseParser("$.result.scores[*].score", null, null); + RankedDocsResults parsedResults = (RankedDocsResults) parser.parse( + new 
HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is( + new RankedDocsResults( + List.of(new RankedDocsResults.RankedDoc(0, 1.37f, null), new RankedDocsResults.RankedDoc(1, -0.3f, null)) + ) + ) + ); + } + + public void testParse_CohereResponseFormat() throws IOException { + String responseJson = """ + { + "index": "44873262-1315-4c06-8433-fdc90c9790d0", + "results": [ + { + "document": { + "text": "Washington, D.C.." + }, + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": { + "text": "Capital punishment has existed in the United States since beforethe United States was a country. " + }, + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": { + "text": "Carson City is the capital city of the American state of Nevada." + }, + "index": 0, + "relevance_score": 0.10194652 + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "search_units": 1 + } + } + } + """; + + var parser = new RerankResponseParser("$.results[*].relevance_score", "$.results[*].index", "$.results[*].document.text"); + RankedDocsResults parsedResults = (RankedDocsResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307f, "Washington, D.C.."), + new RankedDocsResults.RankedDoc( + 3, + 0.27904198f, + "Capital punishment has existed in the United States since beforethe United States was a country. 
" + ), + new RankedDocsResults.RankedDoc(0, 0.10194652f, "Carson City is the capital city of the American state of Nevada.") + ) + ) + ) + ); + } + + @Override + protected RerankResponseParser mutateInstanceForVersion(RerankResponseParser instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return RerankResponseParser::new; + } + + @Override + protected RerankResponseParser createTestInstance() { + return createRandom(); + } + + @Override + protected RerankResponseParser mutateInstance(RerankResponseParser instance) throws IOException { + return randomValueOtherThan(instance, RerankResponseParserTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java new file mode 100644 index 0000000000000..c4b69ae8c8b19 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/SparseEmbeddingResponseParserTests.java @@ -0,0 +1,349 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH; +import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class SparseEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase { + + public static SparseEmbeddingResponseParser createRandom() { + return new SparseEmbeddingResponseParser(randomAlphaOfLength(5), randomAlphaOfLength(5)); + } + + public void testFromMap() { + var validation = new ValidationException(); + var parser = SparseEmbeddingResponseParser.fromMap( + new HashMap<>( + Map.of( + SPARSE_EMBEDDING_TOKEN_PATH, + "$.result[*].embeddings[*].token", + SPARSE_EMBEDDING_WEIGHT_PATH, + "$.result[*].embeddings[*].weight" + ) + ), + validation + ); + + assertThat(parser, is(new SparseEmbeddingResponseParser("$.result[*].embeddings[*].token", 
"$.result[*].embeddings[*].weight"))); + } + + public void testFromMap_ThrowsException_WhenRequiredFieldsAreNotPresent() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> SparseEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("not_path", "$.result[*].embeddings")), validation) + ); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: [json_parser] does not contain the required setting [token_path];" + + "2: [json_parser] does not contain the required setting [weight_path];" + ) + ); + } + + public void testToXContent() throws IOException { + var entity = new SparseEmbeddingResponseParser("$.result.path.token", "$.result.path.weight"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "json_parser": { + "token_path": "$.result.path.token", + "weight_path": "$.result.path.weight" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testParse() throws IOException { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "tokenId": 6, + "weight": 0.101 + }, + { + "tokenId": 163040, + "weight": 0.28417 + } + ] + } + ] + } + } + """; + + var parser = new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" + ); + SparseEmbeddingResults parsedResults = (SparseEmbeddingResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is( + new SparseEmbeddingResults( + List.of( + new 
SparseEmbeddingResults.Embedding( + List.of(new WeightedToken("6", 0.101f), new WeightedToken("163040", 0.28417f)), + false + ) + ) + ) + ) + ); + } + + public void testParse_ThrowsException_WhenTheTokenField_IsNotAnArray() { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "tokenId": 6, + "weight": [0.101] + } + ] + } + } + """; + + var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse sparse embedding entry [0], error: Extracted field [result.sparse_embeddings.tokenId] " + + "is an invalid type, expected a list but received [Integer]" + ) + ); + } + + public void testParse_ThrowsException_WhenTheTokenArraySize_AndWeightArraySize_AreDifferent() { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "tokenId": [6, 7], + "weight": [0.101] + } + ] + } + } + """; + + var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse sparse embedding entry [0], error: The extracted tokens list is size [2] " + + "but the weights list is size [1]. The list sizes must be equal." 
+ ) + ); + } + + public void testParse_ThrowsException_WhenTheWeightValue_IsNotAFloat() { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "tokenId": [6], + "weight": [true] + } + ] + } + } + """; + + var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse sparse embedding entry [0], error: Failed to parse weight item: " + + "[0] of array, error: Unable to convert field [result.sparse_embeddings.weight] of type [Boolean] to Number" + ) + ); + } + + public void testParse_ThrowsException_WhenTheWeightField_IsNotAnArray() { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + "usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "tokenId": [6], + "weight": 0.101 + } + ] + } + } + """; + + var parser = new SparseEmbeddingResponseParser("$.result.sparse_embeddings[*].tokenId", "$.result.sparse_embeddings[*].weight"); + + var exception = expectThrows( + IllegalStateException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse sparse embedding entry [0], error: Extracted field [result.sparse_embeddings.weight] " + + "is an invalid type, expected a list but received [Double]" + ) + ); + } + + public void testParse_ThrowsException_WhenExtractedField_IsNotFormattedCorrectly() { + String responseJson = """ + { + "request_id": "75C50B5B-E79E-4930-****-F48DBB392231", + "latency": 22, + 
"usage": { + "token_count": 11 + }, + "result": { + "sparse_embeddings": [ + { + "index": 0, + "embedding": [ + { + "6": 0.101 + }, + { + "163040": 0.28417 + } + ] + } + ] + } + } + """; + + var parser = new SparseEmbeddingResponseParser( + "$.result.sparse_embeddings[*].embedding[*].tokenId", + "$.result.sparse_embeddings[*].embedding[*].weight" + ); + var exception = expectThrows( + IllegalArgumentException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat(exception.getMessage(), is("Unable to find field [tokenId] in map")); + } + + @Override + protected SparseEmbeddingResponseParser mutateInstanceForVersion(SparseEmbeddingResponseParser instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return SparseEmbeddingResponseParser::new; + } + + @Override + protected SparseEmbeddingResponseParser createTestInstance() { + return createRandom(); + } + + @Override + protected SparseEmbeddingResponseParser mutateInstance(SparseEmbeddingResponseParser instance) throws IOException { + return randomValueOtherThan(instance, SparseEmbeddingResponseParserTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java new file mode 100644 index 0000000000000..b240e07a66336 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/TextEmbeddingResponseParserTests.java @@ -0,0 +1,263 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class TextEmbeddingResponseParserTests extends AbstractBWCWireSerializationTestCase { + + public static TextEmbeddingResponseParser createRandom() { + return new TextEmbeddingResponseParser("$." 
+ randomAlphaOfLength(5)); + } + + public void testFromMap() { + var validation = new ValidationException(); + var parser = TextEmbeddingResponseParser.fromMap( + new HashMap<>(Map.of(TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result[*].embeddings")), + validation + ); + + assertThat(parser, is(new TextEmbeddingResponseParser("$.result[*].embeddings"))); + } + + public void testFromMap_ThrowsException_WhenRequiredFieldIsNotPresent() { + var validation = new ValidationException(); + var exception = expectThrows( + ValidationException.class, + () -> TextEmbeddingResponseParser.fromMap(new HashMap<>(Map.of("some_field", "$.result[*].embeddings")), validation) + ); + + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [json_parser] does not contain " + "the required setting [text_embeddings];") + ); + } + + public void testToXContent() throws IOException { + var entity = new TextEmbeddingResponseParser("$.result.path"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + { + builder.startObject(); + entity.toXContent(builder, null); + builder.endObject(); + } + String xContentResult = Strings.toString(builder); + + var expected = XContentHelper.stripWhitespace(""" + { + "json_parser": { + "text_embeddings": "$.result.path" + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testParse() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is(new 
TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))) + ); + } + + public void testParse_MultipleEmbeddings() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 1, + -2 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) parser.parse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults, + is( + new TextEmbeddingFloatResults( + List.of( + new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), + new TextEmbeddingFloatResults.Embedding(new float[] { 1F, -2F }) + ) + ) + ) + ); + } + + public void testParse_ThrowsException_WhenExtractedField_IsNotAListOfFloats() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1, + -0.015288644 + ] + }, + { + "object": "embedding", + "index": 0, + "embedding": [ + true, + -0.015288644 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var exception = expectThrows( + IllegalArgumentException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse text embedding entry [1], error: Failed to parse list entry [0], error:" + + " Unable to convert field [data.embedding] of type [Boolean] to Number" + ) + ); + } + + 
public void testParse_ThrowsException_WhenExtractedField_IsNotAList() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": 1 + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var parser = new TextEmbeddingResponseParser("$.data[*].embedding"); + var exception = expectThrows( + IllegalArgumentException.class, + () -> parser.parse(new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))) + ); + + assertThat( + exception.getMessage(), + is( + "Failed to parse text embedding entry [0], error: Extracted field [data.embedding] " + + "is an invalid type, expected a list but received [Integer]" + ) + ); + } + + @Override + protected TextEmbeddingResponseParser mutateInstanceForVersion(TextEmbeddingResponseParser instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return TextEmbeddingResponseParser::new; + } + + @Override + protected TextEmbeddingResponseParser createTestInstance() { + return createRandom(); + } + + @Override + protected TextEmbeddingResponseParser mutateInstance(TextEmbeddingResponseParser instance) throws IOException { + return randomValueOtherThan(instance, TextEmbeddingResponseParserTests::createRandom); + } +}