Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import org.elasticsearch.common.Strings;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

/**
Expand Down Expand Up @@ -78,6 +80,8 @@
* [1, 2]
* ]
* }
*
* The array field names would be {@code ["embeddings", "embedding"}
* </pre>
*
* This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array.
Expand Down Expand Up @@ -123,10 +127,28 @@ public class MapPathExtractor {
private static final String DOLLAR = "$";

// default for testing
static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)");
static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)");
static final Pattern DOT_FIELD_PATTERN = Pattern.compile("^\\.([^.\\[]+)(.*)");
static final Pattern ARRAY_WILDCARD_PATTERN = Pattern.compile("^\\[\\*\\](.*)");
public static final String UNKNOWN_FIELD_NAME = "unknown";

/**
* A result object that tries to match up the field names parsed from the passed in path and the result
* extracted from the passed in map.
* @param extractedObject represents the extracted result from the map
* @param traversedFields a list of field names in order as they're encountered while navigating through the nested objects
*/
public record Result(Object extractedObject, List<String> traversedFields) {
public String getArrayFieldName(int index) {
// if the index is out of bounds we'll return a default value
if (traversedFields.size() <= index || index < 0) {
return UNKNOWN_FIELD_NAME;
}

return traversedFields.get(index);
}
}

public static Object extract(Map<String, Object> data, String path) {
public static Result extract(Map<String, Object> data, String path) {
if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) {
return null;
}
Expand All @@ -139,16 +161,41 @@ public static Object extract(Map<String, Object> data, String path) {
throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath));
}

return navigate(data, cleanedPath);
var fieldNames = new LinkedHashSet<String>();

return new Result(navigate(data, cleanedPath, new FieldNameInfo("", "", fieldNames)), fieldNames.stream().toList());
}

private static Object navigate(Object current, String remainingPath) {
if (current == null || remainingPath == null || remainingPath.isEmpty()) {
private record FieldNameInfo(String currentPath, String fieldName, Set<String> traversedFields) {
void addTraversedField(String fieldName) {
traversedFields.add(createPath(fieldName));
}

void addCurrentField() {
traversedFields.add(currentPath);
}

FieldNameInfo descend(String newFieldName) {
var newLocation = createPath(newFieldName);
return new FieldNameInfo(newLocation, newFieldName, traversedFields);
}

private String createPath(String newFieldName) {
if (Strings.isNullOrEmpty(currentPath)) {
return newFieldName;
} else {
return currentPath + "." + newFieldName;
}
}
}

private static Object navigate(Object current, String remainingPath, FieldNameInfo fieldNameInfo) {
if (current == null || Strings.isNullOrEmpty(remainingPath)) {
return current;
}

var dotFieldMatcher = dotFieldPattern.matcher(remainingPath);
var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath);
var dotFieldMatcher = DOT_FIELD_PATTERN.matcher(remainingPath);
var arrayWildcardMatcher = ARRAY_WILDCARD_PATTERN.matcher(remainingPath);

if (dotFieldMatcher.matches()) {
String field = dotFieldMatcher.group(1);
Expand All @@ -168,7 +215,12 @@ private static Object navigate(Object current, String remainingPath) {
throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field));
}

return navigate(currentMap.get(field), nextPath);
// Handle the case where the path was $.result.text or $.result[*].key
if (Strings.isNullOrEmpty(nextPath)) {
fieldNameInfo.addTraversedField(field);
}

return navigate(currentMap.get(field), nextPath, fieldNameInfo.descend(field));
} else {
throw new IllegalArgumentException(
Strings.format(
Expand All @@ -182,10 +234,12 @@ private static Object navigate(Object current, String remainingPath) {
} else if (arrayWildcardMatcher.matches()) {
String nextPath = arrayWildcardMatcher.group(1);
if (current instanceof List<?> list) {
fieldNameInfo.addCurrentField();

List<Object> results = new ArrayList<>();

for (Object item : list) {
Object result = navigate(item, nextPath);
Object result = navigate(item, nextPath, fieldNameInfo);
if (result != null) {
results.add(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,16 @@ public String getErrorMessage() {
public boolean errorStructureFound() {
return errorStructureFound;
}

@Override
public boolean equals(Object o) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed these so the assertThat calls in the error parser tests can succeed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively could ErrorResponse be a record rather than an class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that would make it cleaner. At the moment ErrorResponse is a base class and extended in a few places so we'd have to switch that to composition (I don't think we can extend a record 🤔 ).

if (o == null || getClass() != o.getClass()) return false;
ErrorResponse that = (ErrorResponse) o;
return errorStructureFound == that.errorStructureFound && Objects.equals(errorMessage, that.errorMessage);
}

@Override
public int hashCode() {
return Objects.hash(errorMessage, errorStructureFound);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom;

public class CustomServiceSettings {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all of these values are referenced but they will be once I pull in the rest of the custom service changes in future PRs.

public static final String NAME = "custom_service_settings";
public static final String URL = "url";
public static final String HEADERS = "headers";
public static final String REQUEST = "request";
public static final String REQUEST_CONTENT = "content";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom.response;

import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;

public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser {

@Override
public InferenceServiceResults parse(HttpResult response) throws IOException {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
var map = jsonParser.map();

return transform(map);
}
}

protected abstract T transform(Map<String, Object> extractedField);

static List<?> validateList(Object obj, String fieldName) {
validateNonNull(obj, fieldName);

if (obj instanceof List<?> == false) {
throw new IllegalArgumentException(
Strings.format(
"Extracted field [%s] is an invalid type, expected a list but received [%s]",
fieldName,
obj.getClass().getSimpleName()
)
);
}

return (List<?>) obj;
}

static void validateNonNull(Object obj, String fieldName) {
Objects.requireNonNull(obj, Strings.format("Failed to parse field [%s], extracted field was null", fieldName));
}

static Map<String, Object> validateMap(Object obj, String fieldName) {
validateNonNull(obj, fieldName);

if (obj instanceof Map<?, ?> == false) {
throw new IllegalArgumentException(
Strings.format(
"Extracted field [%s] is an invalid type, expected a map but received [%s]",
fieldName,
obj.getClass().getSimpleName()
)
);
}

var keys = ((Map<?, ?>) obj).keySet();
for (var key : keys) {
if (key instanceof String == false) {
throw new IllegalStateException(
Strings.format(
"Extracted field [%s] map has an invalid key type. Expected a string but received [%s]",
fieldName,
key.getClass().getSimpleName()
)
);
}
}

@SuppressWarnings("unchecked")
var result = (Map<String, Object>) obj;
return result;
}

static List<Float> convertToListOfFloats(Object obj, String fieldName) {
return castList(validateList(obj, fieldName), BaseCustomResponseParser::toFloat, fieldName);
}

static Float toFloat(Object obj, String fieldName) {
return toNumber(obj, fieldName).floatValue();
}

private static Number toNumber(Object obj, String fieldName) {
if (obj instanceof Number == false) {
throw new IllegalArgumentException(
Strings.format("Unable to convert field [%s] of type [%s] to Number", fieldName, obj.getClass().getSimpleName())
);
}

return ((Number) obj);
}

static List<Integer> convertToListOfIntegers(Object obj, String fieldName) {
return castList(validateList(obj, fieldName), BaseCustomResponseParser::toInteger, fieldName);
}

private static Integer toInteger(Object obj, String fieldName) {
return toNumber(obj, fieldName).intValue();
}

static <T> List<T> castList(List<?> items, BiFunction<Object, String, T> converter, String fieldName) {
validateNonNull(items, fieldName);

List<T> resultList = new ArrayList<>();
for (int i = 0; i < items.size(); i++) {
try {
resultList.add(converter.apply(items.get(i), fieldName));
} catch (Exception e) {
throw new IllegalStateException(Strings.format("Failed to parse list entry [%d], error: %s", i, e.getMessage()), e);
}
}

return resultList;
}

static <T> T toType(Object obj, Class<T> type, String fieldName) {
validateNonNull(obj, fieldName);

if (type.isInstance(obj) == false) {
throw new IllegalArgumentException(
Strings.format(
"Unable to convert field [%s] of type [%s] to [%s]",
fieldName,
obj.getClass().getSimpleName(),
type.getSimpleName()
)
);
}

return type.cast(obj);
}
}
Loading