Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
public static final String URL = "url";
public static final String HEADERS = "headers";
public static final String REQUEST = "request";
public static final String REQUEST_CONTENT = "content";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";
Expand All @@ -83,14 +82,7 @@ public static CustomServiceSettings fromMap(
removeNullValues(headers);
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false);

Map<String, Object> requestBodyMap = extractRequiredMap(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException);

String requestContentString = extractRequiredString(
Objects.requireNonNullElse(requestBodyMap, new HashMap<>()),
REQUEST_CONTENT,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);
String requestContentString = extractRequiredString(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException);

Map<String, Object> responseParserMap = extractRequiredMap(
map,
Expand Down Expand Up @@ -125,11 +117,10 @@ public static CustomServiceSettings fromMap(
context
);

if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
throw validationException;
}

throwIfNotEmptyMap(requestBodyMap, REQUEST, NAME);
throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME);
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME);
throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME);
Expand Down Expand Up @@ -335,11 +326,7 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder

queryParameters.toXContent(builder, params);

builder.startObject(REQUEST);
{
builder.field(REQUEST_CONTENT, requestContentString);
}
builder.endObject();
builder.field(REQUEST, requestContentString);

builder.startObject(RESPONSE);
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import java.util.Objects;

import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_CONTENT;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST;
import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL;

public class CustomRequest implements Request {
Expand Down Expand Up @@ -133,7 +133,7 @@ private void setHeaders(HttpRequestBase httpRequest) {
private void setRequestContent(HttpPost httpRequest) {
String replacedRequestContentString = jsonPlaceholderReplacer.replace(
model.getServiceSettings().getRequestContentString(),
REQUEST_CONTENT
REQUEST
);
StringEntity stringEntity = new StringEntity(replacedRequestContentString, StandardCharsets.UTF_8);
httpRequest.setEntity(stringEntity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
* To use this class, extend it and pass the constructor a configuration.
* </p>
*/
public abstract class AbstractServiceTests extends ESTestCase {
public abstract class AbstractInferenceServiceTests extends ESTestCase {

protected final MockWebServer webServer = new MockWebServer();
protected ThreadPool threadPool;
Expand All @@ -80,7 +80,7 @@ public void tearDown() throws Exception {

private final TestConfiguration testConfiguration;

public AbstractServiceTests(TestConfiguration testConfiguration) {
public AbstractInferenceServiceTests(TestConfiguration testConfiguration) {
this.testConfiguration = Objects.requireNonNull(testConfiguration);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void testFromMap() {
QueryParameters.QUERY_PARAMETERS,
queryParameters,
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -179,7 +179,7 @@ public void testFromMap_WithOptionalsNotSpecified() {
CustomServiceSettings.URL,
url,
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -243,7 +243,7 @@ public void testFromMap_RemovesNullValues_FromMaps() {
CustomServiceSettings.HEADERS,
headersWithNulls,
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -304,7 +304,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", 1)),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -353,7 +353,7 @@ public void testFromMap_ReturnsError_IfQueryParamsContainsNonStringValues() {
QueryParameters.QUERY_PARAMETERS,
List.of(List.of("key", 1)),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -393,7 +393,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
"invalid_request",
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand All @@ -413,13 +413,7 @@ public void testFromMap_ReturnsError_IfRequestMapIsMissing() {
() -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id")
);

assertThat(
exception.getMessage(),
is(
"Validation Failed: 1: [service_settings] does not contain the required setting [request];"
+ "2: [service_settings] does not contain the required setting [content];"
)
);
assertThat(exception.getMessage(), is("Validation Failed: 1: [service_settings] does not contain the required setting [request];"));
}

public void testFromMap_ReturnsError_IfResponseMapIsMissing() {
Expand All @@ -433,7 +427,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
"invalid_response",
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -464,46 +458,6 @@ public void testFromMap_ReturnsError_IfResponseMapIsMissing() {
);
}

public void testFromMap_ReturnsError_IfRequestMapIsNotEmptyAfterParsing() {
String url = "http://www.abc.com";
String requestContentString = "request body";

var mapSettings = new HashMap<String, Object>(
Map.of(
CustomServiceSettings.URL,
url,
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString, "key", "value")),
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
CustomServiceSettings.JSON_PARSER,
new HashMap<>(
Map.of(TextEmbeddingResponseParser.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
),
CustomServiceSettings.ERROR_PARSER,
new HashMap<>(Map.of(ErrorResponseParser.MESSAGE_PATH, "$.error.message"))
)
)
)
);

var exception = expectThrows(
ElasticsearchStatusException.class,
() -> CustomServiceSettings.fromMap(mapSettings, ConfigurationParseContext.REQUEST, TaskType.TEXT_EMBEDDING, "inference_id")
);

assertThat(
exception.getMessage(),
is(
"Configuration contains unknown settings [{key=value}] while parsing field [request]"
+ " for settings [custom_service_settings]"
)
);
}

public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() {
String url = "http://www.abc.com";
String requestContentString = "request body";
Expand All @@ -515,7 +469,7 @@ public void testFromMap_ReturnsError_IfJsonParserMapIsNotEmptyAfterParsing() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -560,7 +514,7 @@ public void testFromMap_ReturnsError_IfResponseMapIsNotEmptyAfterParsing() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -602,7 +556,7 @@ public void testFromMap_ReturnsError_IfErrorParserMapIsNotEmptyAfterParsing() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -642,7 +596,7 @@ public void testFromMap_ReturnsError_IfTaskTypeIsInvalid() {
CustomServiceSettings.HEADERS,
new HashMap<>(Map.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, requestContentString)),
requestContentString,
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down Expand Up @@ -687,9 +641,7 @@ public void testXContent() throws IOException {
"headers": {
"key": "value"
},
"request": {
"content": "string"
},
"request": "string",
"response": {
"json_parser": {
"text_embeddings": "$.result.embeddings[*].embedding"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.services.AbstractServiceTests;
import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
Expand Down Expand Up @@ -54,7 +54,7 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;

public class CustomServiceTests extends AbstractServiceTests {
public class CustomServiceTests extends AbstractInferenceServiceTests {

public CustomServiceTests() {
super(createTestConfiguration());
Expand Down Expand Up @@ -150,7 +150,7 @@ private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
QueryParameters.QUERY_PARAMETERS,
List.of(List.of("key", "value")),
CustomServiceSettings.REQUEST,
new HashMap<>(Map.of(CustomServiceSettings.REQUEST_CONTENT, "request body")),
"request body",
CustomServiceSettings.RESPONSE,
new HashMap<>(
Map.of(
Expand Down