Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,16 @@ public String getErrorMessage() {
public boolean errorStructureFound() {
return errorStructureFound;
}

@Override
public boolean equals(Object o) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed these so the assertThat calls in the error parser tests can succeed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively could ErrorResponse be a record rather than an class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that would make it cleaner. At the moment ErrorResponse is a base class and extended in a few places so we'd have to switch that to composition (I don't think we can extend a record 🤔 ).

if (o == null || getClass() != o.getClass()) return false;
ErrorResponse that = (ErrorResponse) o;
return errorStructureFound == that.errorStructureFound && Objects.equals(errorMessage, that.errorMessage);
}

@Override
public int hashCode() {
return Objects.hash(errorMessage, errorStructureFound);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom;

public class CustomServiceSettings {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all of these values are referenced but they will be once I pull in the rest of the custom service changes in future PRs.

public static final String NAME = "custom_service_settings";
public static final String URL = "url";
public static final String HEADERS = "headers";
public static final String REQUEST = "request";
public static final String REQUEST_CONTENT = "content";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom.response;

import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

public abstract class BaseCustomResponseParser<T extends InferenceServiceResults> implements CustomResponseParser {

@Override
public InferenceServiceResults parse(HttpResult response) throws IOException {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
var map = jsonParser.map();

return transform(map);
}
}

protected abstract T transform(Map<String, Object> extractedField);

static List<?> validateList(Object obj) {
validateNonNull(obj);

if (obj instanceof List<?> == false) {
throw new IllegalArgumentException(
Strings.format("Extracted field is an invalid type, expected a list but received [%s]", obj.getClass().getSimpleName())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be useful to include the name of the field in this error and the others e.g

Extracted field [embedding] is an invalid type, expected a list but received [%s]

);
}

return (List<?>) obj;
}

static void validateNonNull(Object obj) {
Objects.requireNonNull(obj, "Failed to parse response, extracted field was null");
}

static Map<String, Object> validateMap(Object obj) {
validateNonNull(obj);

if (obj instanceof Map<?, ?> == false) {
throw new IllegalArgumentException(
Strings.format("Extracted field is an invalid type, expected a map but received [%s]", obj.getClass().getSimpleName())
);
}

var keys = ((Map<?, ?>) obj).keySet();
for (var key : keys) {
if (key instanceof String == false) {
throw new IllegalStateException(
Strings.format(
"Extracted map has an invalid key type. Expected a string but received [%s]",
key.getClass().getSimpleName()
)
);
}
}

@SuppressWarnings("unchecked")
var result = (Map<String, Object>) obj;
return result;
}

static List<Float> convertToListOfFloats(Object obj) {
return validateAndCastList(validateList(obj), BaseCustomResponseParser::toFloat);
}

static Float toFloat(Object obj) {
return toNumber(obj).floatValue();
}

private static Number toNumber(Object obj) {
if (obj instanceof Number == false) {
throw new IllegalArgumentException(Strings.format("Unable to convert type [%s] to Number", obj.getClass().getSimpleName()));
}

return ((Number) obj);
}

static List<Integer> convertToListOfIntegers(Object obj) {
return validateAndCastList(validateList(obj), BaseCustomResponseParser::toInteger);
}

private static Integer toInteger(Object obj) {
return toNumber(obj).intValue();
}

static <T> List<T> validateAndCastList(List<?> items, Function<Object, T> converter) {
validateNonNull(items);

List<T> resultList = new ArrayList<>();
for (var obj : items) {
resultList.add(converter.apply(obj));
}

return resultList;
}

static <T> T toType(Object obj, Class<T> type) {
validateNonNull(obj);

if (type.isInstance(obj) == false) {
throw new IllegalArgumentException(
Strings.format("Unable to convert object of type [%s] to type [%s]", obj.getClass().getSimpleName(), type.getSimpleName())
);
}

return type.cast(obj);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom.response;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.common.MapPathExtractor;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;

public class CompletionResponseParser extends BaseCustomResponseParser<ChatCompletionResults> {

public static final String NAME = "completion_response_parser";
public static final String COMPLETION_PARSER_RESULT = "completion_result";

private final String completionResultPath;

public static CompletionResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
var path = extractRequiredString(responseParserMap, COMPLETION_PARSER_RESULT, JSON_PARSER, validationException);

if (path == null) {
throw validationException;
}

return new CompletionResponseParser(path);
}

public CompletionResponseParser(String completionResultPath) {
this.completionResultPath = Objects.requireNonNull(completionResultPath);
}

public CompletionResponseParser(StreamInput in) throws IOException {
this.completionResultPath = in.readString();
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(completionResultPath);
}

public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(JSON_PARSER);
{
builder.field(COMPLETION_PARSER_RESULT, completionResultPath);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
CompletionResponseParser that = (CompletionResponseParser) o;
return Objects.equals(completionResultPath, that.completionResultPath);
}

@Override
public int hashCode() {
return Objects.hash(completionResultPath);
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public ChatCompletionResults transform(Map<String, Object> map) {
var extractedField = MapPathExtractor.extract(map, completionResultPath);

validateNonNull(extractedField);

if (extractedField instanceof List<?> extractedList) {
var completionList = validateAndCastList(extractedList, (obj) -> toType(obj, String.class));
return new ChatCompletionResults(completionList.stream().map(ChatCompletionResults.Result::new).toList());
} else if (extractedField instanceof String extractedString) {
return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(extractedString)));
} else {
throw new IllegalArgumentException(
Strings.format(
"Extracted field is an invalid type, expected a list or a string but received [%s]",
extractedField.getClass().getSimpleName()
)
);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom.response;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

import java.io.IOException;

public interface CustomResponseParser extends ToXContentFragment, NamedWriteable {
Copy link
Contributor Author

@jonathan-buttner jonathan-buttner Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The response parsers need to be a named writeable because the service settings will use a factory to construct the appropriate response parser. The response parser is chosen based on the task type of the service which is provided in the PUT request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll register these in a future PR

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we plan to handle streaming, since HttpResult and StreamingHttpResult are different? We could maybe have CustomerResponseParser<T> and change the current implementation to BaseCustomResponseParser<T> implements CustomResponseParser<HttpResult> so we can make a StreamingBaseCustomResponseParser<T> implements CustomResponseParser<StreamingHttpResult>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good point, I hadn't really thought about that. I think you're suggestion will work well though.

InferenceServiceResults parse(HttpResult response) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.custom.response;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.ERROR_PARSER;
import static org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser.toType;

public class ErrorResponseParser implements ToXContentFragment, Function<HttpResult, ErrorResponse> {

public static final String MESSAGE_PATH = "path";

private final String messagePath;

public static ErrorResponseParser fromMap(Map<String, Object> responseParserMap, ValidationException validationException) {
var path = extractRequiredString(responseParserMap, MESSAGE_PATH, ERROR_PARSER, validationException);

if (path == null) {
throw validationException;
}

return new ErrorResponseParser(path);
}

public ErrorResponseParser(String messagePath) {
this.messagePath = Objects.requireNonNull(messagePath);
}

public ErrorResponseParser(StreamInput in) throws IOException {
this.messagePath = in.readString();
}

public void writeTo(StreamOutput out) throws IOException {
out.writeString(messagePath);
}

public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(ERROR_PARSER);
{
builder.field(MESSAGE_PATH, messagePath);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ErrorResponseParser that = (ErrorResponseParser) o;
return Objects.equals(messagePath, that.messagePath);
}

@Override
public int hashCode() {
return Objects.hash(messagePath);
}

@Override
public ErrorResponse apply(HttpResult httpResult) {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, httpResult.body())
) {
var map = jsonParser.map();

// NOTE: This deviates from what we've done in the past. In the ErrorMessageResponseEntity logic
// if we find the top level error field we'll return a response with an empty message but indicate
// that we found the structure of the error object. Here if we're missing the final field we will return
// a ErrorResponse.UNDEFINED_ERROR with will indicate that we did not find the structure even if for example
// the outer error field does exist, but it doesn't contain the nested field we were looking for.
// If in the future we want the previous behavior, we can add a new message_path field or something and have
// the current path field point to the field that indicates whether we found an error object.
var errorText = toType(MapPathExtractor.extract(map, messagePath), String.class);
return new ErrorResponse(errorText);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}
}
Loading