diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml
index 010ac5528f46e..fc9f4ccfb63f4 100644
--- a/gradle/verification-metadata.xml
+++ b/gradle/verification-metadata.xml
@@ -11,6 +11,7 @@
+
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 4d94ff717f205..c3c34636d70b8 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -185,6 +185,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM = def(9_023_0_00);
public static final TransportVersion MAX_OPERATION_SIZE_REJECTIONS_ADDED = def(9_024_0_00);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR = def(9_025_0_00);
+ public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_026_0_00); // added with TaskType.CUSTOM and the custom inference service; presumably gates their wire format - confirm at use sites
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java
index 73a0e3cc8a774..abac5eee3ae12 100644
--- a/server/src/main/java/org/elasticsearch/inference/TaskType.java
+++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java
@@ -25,6 +25,12 @@ public enum TaskType implements Writeable {
SPARSE_EMBEDDING,
RERANK,
COMPLETION,
+ CUSTOM { // task type used by the custom inference service integration
+ @Override
+ public boolean isAnyOrSame(TaskType other) {
+ return true; // unconditional: CUSTOM reports any-or-same for every other task type. NOTE(review): this makes CUSTOM act as a wildcard - confirm intended
+ }
+ },
ANY {
@Override
public boolean isAnyOrSame(TaskType other) {
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
index dc177795af76a..cce33232724ea 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -199,6 +199,14 @@ public ActionRequestValidationException validate() {
}
}
+ if (taskType.equals(TaskType.CUSTOM)) {
+ if (query == null) {
+ var e = new ActionRequestValidationException();
+ e.addValidationError(format("Field [query] cannot be null for task type [%s]", TaskType.CUSTOM));
+ return e;
+ }
+ }
+
return null;
}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/CustomServiceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/CustomServiceResults.java
new file mode 100644
index 0000000000000..f91a4c8cb9fdd
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/CustomServiceResults.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xpack.core.ml.inference.results.CustomResults;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * {@link InferenceServiceResults} holder for the free-form response map produced by a {@link TaskType#CUSTOM}
+ * inference call.
+ * <p>
+ * The map is serialized with the generic-map stream helpers and rendered to XContent as a single object named
+ * {@link #CUSTOM_TYPE}. Neither constructor rejects a {@code null} map, so {@link #equals(Object)},
+ * {@link #hashCode()} and {@link #toString()} are written to tolerate {@code null} data.
+ */
+public class CustomServiceResults implements InferenceServiceResults {
+    /** Name under which this class is registered/identified as a named writeable. */
+    public static final String NAME = "custom_service_results";
+    /** Lower-cased name of {@link TaskType#CUSTOM}; used as the enclosing object name in the XContent output. */
+    public static final String CUSTOM_TYPE = TaskType.CUSTOM.name().toLowerCase(Locale.ROOT);
+
+    // Assigned exactly once by either constructor; may be null because neither constructor validates it.
+    final Map<String, Object> data;
+
+    /**
+     * @param data the response payload to wrap; stored as-is (no defensive copy)
+     */
+    public CustomServiceResults(Map<String, Object> data) {
+        this.data = data;
+    }
+
+    /**
+     * Reads an instance previously written by {@link #writeTo(StreamOutput)}.
+     *
+     * @param in the stream to read the generic map from
+     * @throws IOException if the stream cannot be read
+     */
+    public CustomServiceResults(StreamInput in) throws IOException {
+        this.data = in.readGenericMap();
+    }
+
+    /** Renders the wrapped map as one object named {@link #CUSTOM_TYPE}. */
+    @Override
+    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+        return ChunkedToXContentHelper.object(CUSTOM_TYPE, this.asMap());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeGenericMap(data);
+    }
+
+    /** Same representation as {@link #transformToLegacyFormat()}. */
+    @Override
+    public List<? extends InferenceResults> transformToCoordinationFormat() {
+        return transformToLegacyFormat();
+    }
+
+    /** Wraps the whole map in a single {@link CustomResults}. */
+    @Override
+    public List<? extends InferenceResults> transformToLegacyFormat() {
+        return List.of(new CustomResults(data));
+    }
+
+    /** Returns the wrapped map itself, not a copy. */
+    @Override
+    public Map<String, Object> asMap() {
+        return data;
+    }
+
+    /**
+     * Returns {@link #NAME} immediately followed by the hex hash code, a newline, and the map contents.
+     * String concatenation is used so a {@code null} map prints as {@code "null"} instead of throwing.
+     */
+    @Override
+    public String toString() {
+        return NAME + Integer.toHexString(hashCode()) + "\n" + this.asMap();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        CustomServiceResults that = (CustomServiceResults) o;
+        // Objects.equals rather than data.equals: data is not guaranteed to be non-null.
+        return Objects.equals(data, that.data);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(data);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/CustomResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/CustomResults.java
new file mode 100644
index 0000000000000..f91edcdebec9a
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/CustomResults.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * {@link InferenceResults} wrapper around the free-form response map of a {@link TaskType#CUSTOM} inference call.
+ * <p>
+ * The map is serialized with the generic-map stream helpers and written to XContent as a single field named
+ * {@link #CUSTOM_TYPE}. Neither constructor rejects a {@code null} map, so {@link #equals(Object)} and
+ * {@link #hashCode()} tolerate {@code null} data.
+ */
+public class CustomResults implements InferenceResults {
+    /** Name under which this class is registered/identified as a named writeable. */
+    public static final String NAME = "custom_results";
+    /** String form of {@link TaskType#CUSTOM}; used both as the XContent field name and as the results field. */
+    public static final String CUSTOM_TYPE = TaskType.CUSTOM.toString();
+
+    // Assigned exactly once by either constructor; may be null because neither constructor validates it.
+    final Map<String, Object> data;
+
+    /**
+     * @param data the response payload to wrap; stored as-is (no defensive copy)
+     */
+    public CustomResults(Map<String, Object> data) {
+        this.data = data;
+    }
+
+    /**
+     * Reads an instance previously written by {@link #writeTo(StreamOutput)}.
+     *
+     * @param in the stream to read the generic map from
+     * @throws IOException if the stream cannot be read
+     */
+    public CustomResults(StreamInput in) throws IOException {
+        this.data = in.readGenericMap();
+    }
+
+    /** Writes the wrapped map as one field named {@link #CUSTOM_TYPE} into the caller's current object. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(CUSTOM_TYPE, this.asMap());
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeGenericMap(data);
+    }
+
+    @Override
+    public String getResultsField() {
+        return CUSTOM_TYPE;
+    }
+
+    /** Returns the wrapped map itself, not a copy. */
+    @Override
+    public Map<String, Object> asMap() {
+        return data;
+    }
+
+    /**
+     * Returns a new insertion-ordered map with a single entry: {@code outputField -> asMap()}.
+     *
+     * @param outputField key under which the wrapped map is nested
+     */
+    @Override
+    public Map<String, Object> asMap(String outputField) {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(outputField, this.asMap());
+        return map;
+    }
+
+    /** The "predicted value" of a custom result is the whole wrapped map. */
+    @Override
+    public Map<String, Object> predictedValue() {
+        return this.asMap();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        CustomResults that = (CustomResults) o;
+        // Objects.equals rather than data.equals: data is not guaranteed to be non-null.
+        return Objects.equals(data, that.data);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(data);
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
index e9f4df7a523ad..0ed6007c3a783 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
@@ -102,6 +102,21 @@ public void testValidation_Rerank() {
assertNull(e);
}
+ /** A CUSTOM request that carries a non-null query passes validation. */
+ public void testValidation_Custom() {
+     var customRequest = new InferenceAction.Request(TaskType.CUSTOM, "model", "query", List.of("input"), null, null, null, false);
+
+     assertNull(customRequest.validate());
+ }
+
public void testValidation_TextEmbedding_Null() {
InferenceAction.Request inputNullRequest = new InferenceAction.Request(
TaskType.TEXT_EMBEDDING,
@@ -166,6 +181,37 @@ public void testValidation_Rerank_Empty() {
assertThat(queryEmptyError.getMessage(), is("Validation Failed: 1: Field [query] cannot be empty for task type [rerank];"));
}
+ public void testValidation_Custom_Null() {
+ InferenceAction.Request queryNullRequest = new InferenceAction.Request(
+ TaskType.CUSTOM,
+ "model",
+ null,
+ List.of("input"),
+ null,
+ null,
+ null,
+ false
+ );
+ ActionRequestValidationException queryNullError = queryNullRequest.validate();
+ assertNotNull(queryNullError);
+ assertThat(queryNullError.getMessage(), is("Validation Failed: 1: Field [query] cannot be null for task type [custom];"));
+ }
+
+ /**
+  * A CUSTOM request with an empty (but non-null) query is accepted: the CUSTOM validation only rejects a
+  * {@code null} query.
+  */
+ public void testValidation_Custom_Empty() {
+     // Named for what it holds: the query here is the empty string, not null.
+     InferenceAction.Request queryEmptyRequest = new InferenceAction.Request(
+         TaskType.CUSTOM,
+         "model",
+         "",
+         List.of("input"),
+         null,
+         null,
+         null,
+         false
+     );
+     ActionRequestValidationException e = queryEmptyRequest.validate();
+     assertNull(e);
+ }
+
public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
String singleInputRequest = """
{
diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index cb3ccbd171304..9ce4609f254a3 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -38,6 +38,7 @@ dependencies {
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
api "com.ibm.icu:icu4j:${versions.icu4j}"
+ api "org.apache.commons:commons-lang3:${versions.commons_lang3}"
runtimeOnly 'com.google.guava:guava:32.0.1-jre'
implementation 'com.google.code.gson:gson:2.10'
@@ -57,6 +58,11 @@ dependencies {
implementation 'io.grpc:grpc-context:1.49.2'
implementation 'io.opencensus:opencensus-api:0.31.1'
implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1'
+ implementation 'org.apache.commons:commons-text:1.4'
+ implementation 'com.jayway.jsonpath:json-path:2.9.0'
+ implementation 'net.minidev:json-smart:2.5.2'
+ implementation 'net.minidev:accessors-smart:2.5.2'
+
/* AWS SDK v2 */
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java
index 78f30e7da0670..11dd5997765da 100644
--- a/x-pack/plugin/inference/src/main/java/module-info.java
+++ b/x-pack/plugin/inference/src/main/java/module-info.java
@@ -35,6 +35,10 @@
requires org.reactivestreams;
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
+ requires org.apache.commons.text; // commons-text (StringSubstitutor is imported by CustomRequest)
+ requires json.path; // com.jayway.jsonpath:json-path, added in build.gradle
+ requires unboundid.ldapsdk; // NOTE(review): build.gradle adds no LDAP SDK and no visible custom-service source imports com.unboundid - likely an accidental IDE insert; confirm and drop
+ requires json.smart; // net.minidev:json-smart, added in build.gradle
exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
index d57a6b86e4e71..61272dcd95635 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
@@ -58,6 +58,8 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
@@ -149,6 +151,7 @@ public static List getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
+ addCustomWriteables(namedWriteables);
addUnifiedNamedWriteables(namedWriteables);
@@ -663,4 +666,11 @@ private static void addEisNamedWriteables(List nam
)
);
}
+
+ /**
+  * Registers the named writeables of the custom service: one entry for its {@link ServiceSettings} and one for its
+  * {@link TaskSettings}, appended to {@code namedWriteables} in that order.
+  */
+ private static void addCustomWriteables(List namedWriteables) {
+     var serviceSettingsEntry = new NamedWriteableRegistry.Entry(
+         ServiceSettings.class,
+         CustomServiceSettings.NAME,
+         CustomServiceSettings::new
+     );
+     var taskSettingsEntry = new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new);
+
+     namedWriteables.add(serviceSettingsEntry);
+     namedWriteables.add(taskSettingsEntry);
+ }
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 714829c08b041..5d8eeb5f2bad7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -115,6 +115,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
+import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
@@ -358,6 +359,7 @@ public List getInferenceServiceFactories() {
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
+ context -> new CustomService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomAction.java
new file mode 100644
index 0000000000000..4754c321f27b4
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomAction.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.custom;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.CustomRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+/**
+ * {@link ExecutableAction} for a {@link CustomModel}: hands the inference inputs to the {@link Sender} together with a
+ * {@link CustomRequestManager} created for the model.
+ */
+public class CustomAction implements ExecutableAction {
+    private final CustomModel model;
+    private final String failedToSendRequestErrorMessage;
+    private final Sender sender;
+    private final CustomRequestManager requestManager;
+
+    /**
+     * @param sender            used to dispatch the request; must not be null
+     * @param model             the custom model to run inference against; must not be null
+     * @param serviceComponents supplies the thread pool handed to the request manager
+     */
+    public CustomAction(Sender sender, CustomModel model, ServiceComponents serviceComponents) {
+        this.model = Objects.requireNonNull(model);
+        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Custom Search");
+        this.sender = Objects.requireNonNull(sender);
+        this.requestManager = CustomRequestManager.of(model, serviceComponents.threadPool());
+    }
+
+    /**
+     * Sends the inputs through the sender. The listener is wrapped with
+     * {@code wrapFailuresInElasticsearchException} before being passed on. An {@link ElasticsearchException} thrown
+     * while dispatching is forwarded to the listener unchanged; any other exception is first converted with
+     * {@code createInternalServerError}.
+     */
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) {
+        try {
+            sender.send(
+                requestManager,
+                inferenceInputs,
+                timeout,
+                wrapFailuresInElasticsearchException(failedToSendRequestErrorMessage, listener)
+            );
+        } catch (ElasticsearchException elasticsearchException) {
+            listener.onFailure(elasticsearchException);
+        } catch (Exception unexpected) {
+            listener.onFailure(createInternalServerError(unexpected, failedToSendRequestErrorMessage));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionCreator.java
new file mode 100644
index 0000000000000..55c5cbfd0f731
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionCreator.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.custom;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * {@link CustomActionVisitor} implementation that builds a {@link CustomAction} for a {@link CustomModel}, using the
+ * {@link Sender} and {@link ServiceComponents} supplied at construction time.
+ */
+public class CustomActionCreator implements CustomActionVisitor {
+    private final Sender sender;
+    private final ServiceComponents serviceComponents;
+
+    /**
+     * @param sender            passed to every action created; must not be null
+     * @param serviceComponents passed to every action created; must not be null
+     */
+    public CustomActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    /**
+     * Derives a model from {@code model} and {@code taskSettings} via {@code CustomModel.of} and wraps it in a
+     * {@link CustomAction}. {@code inputType} is not used by this implementation.
+     */
+    @Override
+    public ExecutableAction create(CustomModel model, Map taskSettings, InputType inputType) {
+        return new CustomAction(sender, CustomModel.of(model, taskSettings), serviceComponents);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionVisitor.java
new file mode 100644
index 0000000000000..2c66e008671b3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionVisitor.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.custom;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+
+import java.util.Map;
+
+public interface CustomActionVisitor { // factory seam for building actions for custom models; implemented by CustomActionCreator
+ ExecutableAction create(CustomModel model, Map taskSettings, InputType inputType); // returns the action that runs inference for the given model
+}
new file mode 100644
index 0000000000000..fb22ac0832ff8
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/custom/CustomResponseHandler.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.custom;
+
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.custom.CustomErrorResponseEntity;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+
+import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
+
+/**
+ * Response handler for the custom inference integration: validates HTTP responses and turns failure status codes
+ * into {@link RetryException}s.
+ */
+public class CustomResponseHandler extends BaseResponseHandler {
+    /**
+     * @param requestType   label for the kind of request, forwarded to the base handler
+     * @param parseFunction parser for successful responses, forwarded to the base handler
+     */
+    public CustomResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, CustomErrorResponseEntity::fromResponse);
+    }
+
+    /** Checks the status code first, then that the response body is not empty. */
+    @Override
+    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+        throws RetryException {
+        checkForFailureStatusCode(request, result);
+        checkForEmptyBody(throttlerManager, logger, request, result);
+    }
+
+    /**
+     * Validates the status code and throws a {@link RetryException} if it is not in the range [200, 300).
+     * Only a 429 response passes {@code true} as the first {@code RetryException} argument; every other failure
+     * passes {@code false}.
+     *
+     * @param request The http request
+     * @param result  The http response and body
+     * @throws RetryException if the status code is {@code >= 300 or < 200}
+     */
+    @Override
+    protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
+        final int status = result.response().getStatusLine().getStatusCode();
+
+        boolean successful = status >= 200 && status < 300;
+        if (successful) {
+            return;
+        }
+
+        // Every branch below throws, so the checks are evaluated strictly in this order.
+        if (status >= 500) {
+            throw new RetryException(false, buildError(SERVER_ERROR, request, result));
+        }
+        if (status == 429) {
+            throw new RetryException(true, buildError(RATE_LIMIT, request, result));
+        }
+        if (status == 401) {
+            throw new RetryException(false, buildError(AUTHENTICATION, request, result));
+        }
+        if (status >= 300 && status < 400) {
+            throw new RetryException(false, buildError(REDIRECTION, request, result));
+        }
+        throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CustomRequestManager.java
new file mode 100644
index 0000000000000..da76425c93e98
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CustomRequestManager.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.custom.CustomResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.custom.CustomRequest;
+import org.elasticsearch.xpack.inference.external.response.custom.CustomResponseEntity;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+public class CustomRequestManager extends BaseRequestManager {
+ private static final Logger logger = LogManager.getLogger(CustomRequestManager.class);
+
+ private static final ResponseHandler HANDLER = createCustomHandler();
+
+ record RateLimitGrouping(int apiKeyHash) {
+ public static RateLimitGrouping of(CustomModel model) {
+ Objects.requireNonNull(model);
+
+ return new RateLimitGrouping(model.rateLimitServiceSettings().hashCode());
+ }
+ }
+
+ private static ResponseHandler createCustomHandler() {
+ return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse);
+ }
+
+ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) {
+ return new CustomRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
+ }
+
+ private final CustomModel model;
+
+ private CustomRequestManager(CustomModel model, ThreadPool threadPool) {
+ super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+ this.model = model;
+ }
+
+ @Override
+ public void execute(
+ InferenceInputs inferenceInputs,
+ RequestSender requestSender,
+ Supplier hasRequestCompletedFunction,
+ ActionListener listener
+ ) {
+ String query;
+ List input;
+ if (inferenceInputs instanceof QueryAndDocsInputs) {
+ QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs);
+ query = queryAndDocsInputs.getQuery();
+ input = queryAndDocsInputs.getChunks();
+ } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) {
+ query = null;
+ input = chatInputs.getInputs();
+ } else if (inferenceInputs instanceof DocumentsOnlyInput) {
+ DocumentsOnlyInput docsInputs = DocumentsOnlyInput.of(inferenceInputs);
+ query = null;
+ input = docsInputs.getInputs();
+ } else {
+ throw InferenceInputs.createUnsupportedTypeException(inferenceInputs, InferenceInputs.class);
+ }
+ CustomRequest request = new CustomRequest(query, input, model);
+ execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequest.java
new file mode 100644
index 0000000000000..838dc561c6399
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequest.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.custom;
+
+import com.google.gson.Gson;
+
+import org.apache.commons.text.StringSubstitutor;
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpPut;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.entity.StringEntity;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.security.AccessController;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_FORMAT_JSON;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_FORMAT_STRING;
+
+public class CustomRequest implements Request {
+ private final long startTime;
+
+ // Shared JSON serializer used to render non-string parameter values.
+ public static final Gson gson = new Gson();
+
+ // Reserved substitution keys. Declared static because they are constants, not per-request state.
+ // The constructor stores the request's own query and input under these keys.
+ private static final String QUERY = "query";
+ private static final String INPUT = "input";
+
+ private final Map customParams;
+
+ private final CustomServiceSettings serviceSettings;
+ private final CustomTaskSettings taskSettings;
+ private final Map secretParameters;
+ private final String url;
+ private final String path;
+ private final String method;
+ private final String queryString;
+ private final Map headers;
+ private final String requestFormat;
+ private final Map requestContent;
+ private final String requestContentString;
+ private final URI uri;
+ StringSubstitutor substitutor;
+ private final String inferenceEntityId;
+
+ /**
+  * Captures the model's settings and assembles the map used to resolve "${...}" placeholders in the
+  * query string, headers and request body. Task parameters are added first; secret parameters, the
+  * query and the input are added afterwards and therefore replace task parameters with the same key.
+  *
+  * @param query optional query text; stored under "query" when not null
+  * @param input input texts; for a COMPLETION service type with exactly one input the bare element
+  *              is stored instead of the list
+  * @param model source of the service, task and secret settings
+  * @throws NullPointerException if {@code model} is null
+  */
+ public CustomRequest(String query, List input, CustomModel model) {
+     this.startTime = System.currentTimeMillis();
+     Objects.requireNonNull(model);
+
+     serviceSettings = model.getServiceSettings();
+     taskSettings = model.getTaskSettings();
+     secretParameters = model.getSecretSettings().getSecretParameters();
+     path = serviceSettings.getPath();
+     method = serviceSettings.getMethod().toUpperCase(Locale.ROOT);
+     queryString = serviceSettings.getQueryString();
+     headers = serviceSettings.getHeaders();
+     requestFormat = serviceSettings.getRequestFormat();
+     requestContent = serviceSettings.getRequestContent();
+     requestContentString = serviceSettings.getRequestContentString();
+     url = model.getServiceSettings().getUrl();
+
+     // Secure strings are unwrapped to plain text so they can be substituted like any other value.
+     Map<String, Object> overrides = new HashMap<>();
+     if (secretParameters != null) {
+         for (var secret : secretParameters.entrySet()) {
+             Object secretValue = secret.getValue();
+             overrides.put(secret.getKey(), secretValue instanceof SecureString ? secretValue.toString() : secretValue);
+         }
+     }
+
+     // Task parameters go in first so that the overrides collected here can replace them below.
+     customParams = new HashMap<>();
+     var taskParameters = taskSettings.getParameters();
+     if (taskParameters != null && taskParameters.isEmpty() == false) {
+         customParams.putAll(getParameterMap(taskParameters));
+     }
+
+     // if user's custom parameters contain input and query, it will be replaced by inference's input and query
+     if (query != null) {
+         overrides.put(QUERY, query);
+     }
+
+     // A COMPLETION service with exactly one input substitutes the bare element instead of a one-element list.
+     TaskType taskType = TaskType.fromStringOrStatusException(serviceSettings.getServiceType());
+     boolean singleCompletionInput = taskType.equals(TaskType.COMPLETION) && input.size() == 1;
+     overrides.put(INPUT, singleCompletionInput ? input.get(0) : input);
+     customParams.putAll(getParameterMap(overrides));
+
+     // buildUri() reads the substitutor, so it has to be created before the URI is built.
+     substitutor = new StringSubstitutor(customParams, "${", "}");
+     uri = buildUri();
+     inferenceEntityId = model.getInferenceEntityId();
+ }
+
+ @Override
+ public HttpRequest createHttpRequest() {
+     // The method name was upper-cased in the constructor; only GET, POST and PUT are supported.
+     final HttpRequestBase outbound;
+     if (method.equalsIgnoreCase(HttpGet.METHOD_NAME)) {
+         outbound = new HttpGet(uri);
+     } else if (method.equalsIgnoreCase(HttpPost.METHOD_NAME)) {
+         outbound = new HttpPost(uri);
+     } else if (method.equalsIgnoreCase(HttpPut.METHOD_NAME)) {
+         outbound = new HttpPut(uri);
+     } else {
+         throw new IllegalArgumentException("unsupported http method [" + method + "], support GET, PUT and POST");
+     }
+     // Headers first, then the body (a body is only attached to POST and PUT requests).
+     setHeaders(outbound);
+     setRequestContent(outbound);
+     return new HttpRequest(outbound, getInferenceEntityId());
+ }
+
+ // Applies the JSON Content-Type default, then every configured header with its placeholders resolved.
+ private void setHeaders(HttpRequestBase httpRequest) {
+     // Set first so that a user-defined Content-Type header, applied below, replaces it.
+     httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+     if (headers == null) {
+         return;
+     }
+     for (var header : headers.entrySet()) {
+         String resolved = substitutor.replace((String) header.getValue());
+         placeholderValidation(resolved);
+         httpRequest.setHeader(header.getKey(), resolved);
+     }
+ }
+
+ // Attaches the request body. Only the string format is implemented: POST and PUT requests get the
+ // configured content, with placeholders resolved, as a UTF-8 entity. Other cases send no body.
+ private void setRequestContent(HttpRequestBase httpRequest) {
+     // Locale.ROOT: the default-locale toLowerCase() can map ASCII letters differently (Turkish I -> dotless i).
+     switch (requestFormat.toLowerCase(Locale.ROOT)) {
+         case REQUEST_FORMAT_JSON:
+             // todo: support json format request
+             break;
+         case REQUEST_FORMAT_STRING:
+             if (requestContentString != null && (method.equals(HttpPost.METHOD_NAME) || method.equals(HttpPut.METHOD_NAME))) {
+                 String replacedRequestContentString = substitutor.replace(requestContentString);
+                 placeholderValidation(replacedRequestContentString);
+                 StringEntity stringEntity = new StringEntity(replacedRequestContentString, StandardCharsets.UTF_8);
+                 if (httpRequest instanceof HttpPost) {
+                     ((HttpPost) httpRequest).setEntity(stringEntity);
+                 } else if (httpRequest instanceof HttpPut) {
+                     ((HttpPut) httpRequest).setEntity(stringEntity);
+                 }
+             }
+             break;
+     }
+ }
+
+ @Override
+ public String getInferenceEntityId() { // id captured from the model in the constructor
+ return inferenceEntityId;
+ }
+
+ @Override
+ public URI getURI() { // URI produced by buildUri() in the constructor
+ return uri;
+ }
+
+ public CustomServiceSettings getServiceSettings() { // service settings captured from the model in the constructor
+ return serviceSettings;
+ }
+
+ @Override
+ public Request truncate() { // no truncation is performed; the same request is returned
+ return this;
+ }
+
+ @Override
+ public boolean[] getTruncationInfo() { // always null: this request does not track truncation
+ return null;
+ }
+
+ // Joins url and path, appends the query string (when configured) with its placeholders resolved, and
+ // parses the result as a URI.
+ URI buildUri() {
+     StringBuilder target = new StringBuilder().append(url).append(path);
+     if (queryString != null) {
+         String resolvedQuery = substitutor.replace(queryString);
+         placeholderValidation(resolvedQuery);
+         target.append(resolvedQuery);
+     }
+
+     try {
+         return new URI(target.toString());
+     } catch (URISyntaxException e) {
+         // using bad request here so that potentially sensitive URL information does not get logged
+         String message = Strings.format("Failed to construct %s URL", CustomUtils.SERVICE_NAME);
+         throw new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, e);
+     }
+ }
+
+ /**
+  * Converts every parameter value to a string: a String is kept as is, any other value is rendered as
+  * JSON with Gson. Each conversion runs inside a privileged block.
+  */
+ @SuppressWarnings("removal")
+ public static Map getParameterMap(Map parameterObjs) {
+     Map<String, String> rendered = new HashMap<>();
+     parameterObjs.forEach((name, raw) -> {
+         try {
+             AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
+                 rendered.put(name, raw instanceof String ? (String) raw : gson.toJson(raw));
+                 return null;
+             });
+         } catch (PrivilegedActionException e) {
+             throw new RuntimeException(e);
+         }
+     });
+
+     return rendered;
+ }
+
+ // Matches a "${...}" placeholder that survived substitution; compiled once instead of on every call.
+ private static final Pattern UNRESOLVED_PLACEHOLDER = Pattern.compile("\\$\\{.*?\\}");
+ // Rejects text that still contains a placeholder, unless the task settings disable the check.
+ private void placeholderValidation(String s) throws IllegalArgumentException {
+     if (Boolean.TRUE.equals(taskSettings.getIgnorePlaceholderCheck())) {
+         return;
+     }
+     if (UNRESOLVED_PLACEHOLDER.matcher(s).find()) {
+         throw new IllegalArgumentException(String.format(Locale.ROOT, "placeholder is not replaced, find placeholder in [%s]", s));
+     }
+ }
+
+ public long getStartTime() { // System.currentTimeMillis() value captured when the request was constructed
+ return startTime;
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestEntity.java
new file mode 100644
index 0000000000000..6146cd1b7310e
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestEntity.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.custom;
+
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+ public record CustomRequestEntity(CustomServiceSettings serviceSettings, CustomTaskSettings taskSettings) implements ToXContentObject {
+
+     public CustomRequestEntity {
+         Objects.requireNonNull(serviceSettings);
+     }
+     // Writes the request content, then the task parameters, as fields of one object; a null source adds nothing.
+     @Override
+     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+         var requestContent = serviceSettings.getRequestContent();
+         var taskSettingsContent = taskSettings == null ? null : taskSettings.getParameters(); // taskSettings may be null
+         builder.startObject();
+         if (requestContent != null) {
+             for (var field : requestContent.entrySet()) {
+                 builder.field(field.getKey(), field.getValue());
+             }
+         }
+         if (taskSettingsContent != null) {
+             for (var field : taskSettingsContent.entrySet()) {
+                 builder.field(field.getKey(), field.getValue());
+             }
+         }
+         builder.endObject();
+         return builder;
+     }
+ }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomUtils.java
new file mode 100644
index 0000000000000..33cfcd239b1d2
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomUtils.java
@@ -0,0 +1,12 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.custom;
+
+ public class CustomUtils { // constants holder for the custom request classes
+ public static final String SERVICE_NAME = "custom-model"; // interpolated into the URL-construction error message in CustomRequest
+ }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomErrorResponseEntity.java
new file mode 100644
index 0000000000000..fca794f3a29f7
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomErrorResponseEntity.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.custom;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+
+import java.util.Locale;
+
+ public class CustomErrorResponseEntity extends ErrorResponse {
+     private static final Logger logger = LogManager.getLogger(CustomErrorResponseEntity.class);
+
+     private CustomErrorResponseEntity(String errorMessage) {
+         super(errorMessage);
+     }
+
+     /**
+      * Turns an error HTTP result into an {@link ErrorResponse} whose message holds the response and its
+      * JSON body rendered as a map. If the body cannot be parsed, the message reports the parse failure.
+      */
+     public static ErrorResponse fromResponse(HttpResult response) {
+         var parserConfig = XContentParserConfiguration.EMPTY;
+         try (XContentParser bodyParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+             String body = bodyParser.map().toString();
+             if (logger.isDebugEnabled()) {
+                 logger.debug("Received a server error response: {}, body:{}", response.response(), body);
+             }
+             String message = "Received a server error response: %s, body: %s";
+             return new ErrorResponse(String.format(Locale.ROOT, message, response.response(), body));
+         } catch (Exception e) {
+             logger.info("Parsing custom response body failed. Response: {}", response.response(), e);
+             return new ErrorResponse(String.format(Locale.ROOT, "Parsing custom response body failed. Response: %s", response.response()));
+         }
+     }
+ }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntity.java
new file mode 100644
index 0000000000000..7c55cf9a09c27
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntity.java
@@ -0,0 +1,268 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.custom;
+
+import net.minidev.json.JSONArray;
+
+import com.jayway.jsonpath.JsonPath;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.CustomServiceResults;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.request.custom.CustomRequest;
+import org.elasticsearch.xpack.inference.services.custom.ResponseJsonParser;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+
+public class CustomResponseEntity {
+ private static final Logger logger = LogManager.getLogger(CustomResponseEntity.class);
+
+ /**
+  * Parses the HTTP result with the parser matching the task type named by the request's service settings.
+  * @throws ElasticsearchStatusException with BAD_REQUEST when the task type has no parser here
+  */
+ public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+     CustomRequest customRequest = (CustomRequest) request;
+     var settings = customRequest.getServiceSettings();
+     TaskType taskType = TaskType.fromStringOrStatusException(settings.getServiceType());
+     ResponseJsonParser paths = settings.getResponseJsonParser();
+     InferenceServiceResults parsed = switch (taskType) {
+         case TEXT_EMBEDDING -> fromTextEmbeddingResponse(response, paths);
+         case SPARSE_EMBEDDING -> fromSparseEmbeddingResponse(response, paths);
+         case RERANK -> fromRerankResponse(response, paths);
+         case COMPLETION -> fromCompletionResponse(response, paths);
+         case CUSTOM -> fromCustomResponse(response);
+         default -> throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType), RestStatus.BAD_REQUEST);
+     };
+
+     Object uriForLog = request.getURI() != null ? request.getURI() : "";
+     long elapsedMillis = System.currentTimeMillis() - customRequest.getStartTime();
+     logger.debug("Ai Search uri [{}] response: client cost [{}ms]", uriForLog, elapsedMillis);
+     return parsed;
+ }
+
+ // Returns the whole JSON body as a map; no configured response path is applied.
+ private static CustomServiceResults fromCustomResponse(HttpResult response) throws IOException {
+     var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+     try (XContentParser bodyParser = XContentFactory.xContent(XContentType.JSON).createParser(config, response.body())) {
+         return new CustomServiceResults(bodyParser.map());
+     }
+ }
+
+ /**
+  * Reads the embeddings selected by the configured text-embeddings JSON path; every selected element
+  * must itself be a list of numbers. Any failure, including an empty selection, is logged and rethrown
+  * as an IllegalArgumentException.
+  */
+ private static InferenceServiceResults fromTextEmbeddingResponse(HttpResult response, ResponseJsonParser responseJsonParser) {
+     String embeddingPath = responseJsonParser.getTextEmbeddingsPath();
+     var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+     try (XContentParser bodyParser = XContentFactory.xContent(XContentType.JSON).createParser(config, response.body())) {
+         JSONArray rows = JsonPath.read(bodyParser.map(), embeddingPath);
+         if (rows == null || rows.isEmpty()) {
+             String template = "can't parse text_embeddings results from response, please check the path [%s]";
+             throw new IllegalArgumentException(String.format(Locale.ROOT, template, embeddingPath));
+         }
+
+         List<TextEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
+         for (Object row : rows) {
+             List<Float> vector = new ArrayList<>();
+             for (Object component : (List<?>) row) {
+                 vector.add(((Number) component).floatValue());
+             }
+             embeddings.add(TextEmbeddingFloatResults.Embedding.of(vector));
+         }
+
+         return new TextEmbeddingFloatResults(embeddings);
+     } catch (Exception e) {
+         // Also reached by the IllegalArgumentException thrown above, which ends up as the cause.
+         logger.error("failed to parse text_embeddings results from response:", e);
+         throw new IllegalArgumentException("failed to parse text_embeddings results from response:", e);
+     }
+ }
+
+ /**
+  * Builds sparse embeddings from the response. The result path selects one entry per embedding; the token
+  * and weight paths are then evaluated against each entry and must yield arrays of equal length.
+  *
+  * Any failure, including the validation errors raised here, is logged and rethrown as an
+  * IllegalArgumentException.
+  */
+ private static InferenceServiceResults fromSparseEmbeddingResponse(HttpResult response, ResponseJsonParser responseJsonParser) {
+     String resultPath = responseJsonParser.getSparseResultPath();
+     String tokenPath = responseJsonParser.getSparseTokenPath();
+     String weightPath = responseJsonParser.getSparseWeightPath();
+     // The token and weight failures share one message; only the reported path differs.
+     String emptySelection = "can't parse sparse embeddings results from response, please check the path [%s]";
+
+     var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+     try (XContentParser bodyParser = XContentFactory.xContent(XContentType.JSON).createParser(config, response.body())) {
+         Object body = bodyParser.map();
+         JSONArray sparseResults = JsonPath.read(body, resultPath);
+         if (sparseResults == null || sparseResults.isEmpty()) {
+             throw new IllegalArgumentException(
+                 String.format(Locale.ROOT, "can't parse sparse_result results from response, please check the path [%s]", resultPath)
+             );
+         }
+
+         List<SparseEmbeddingResults.Embedding> embeddingList = new ArrayList<>();
+         for (Object entry : sparseResults) {
+             // Both paths are applied to the current entry, not to the whole body.
+             JSONArray tokenResults = JsonPath.read(entry, tokenPath);
+             if (tokenResults == null || tokenResults.isEmpty()) {
+                 throw new IllegalArgumentException(String.format(Locale.ROOT, emptySelection, tokenPath));
+             }
+
+             JSONArray weightResults = JsonPath.read(entry, weightPath);
+             if (weightResults == null || weightResults.isEmpty()) {
+                 throw new IllegalArgumentException(String.format(Locale.ROOT, emptySelection, weightPath));
+             }
+
+             List<String> tokens = tokenResults.stream().map(Object::toString).toList();
+             List<Float> weights = weightResults.stream().map(weight -> ((Number) weight).floatValue()).toList();
+             // Tokens and weights are paired by position.
+             if (tokens.size() != weights.size()) {
+                 throw new IllegalArgumentException("Tokens and weights size does not match");
+             }
+
+             List<WeightedToken> weightedTokens = new ArrayList<>();
+             for (int i = 0; i < tokens.size(); i++) {
+                 weightedTokens.add(new WeightedToken(tokens.get(i), weights.get(i)));
+             }
+             embeddingList.add(new SparseEmbeddingResults.Embedding(weightedTokens, false));
+         }
+
+         return new SparseEmbeddingResults(embeddingList);
+     } catch (Exception e) {
+         // Also reached by the IllegalArgumentExceptions thrown above, which end up as the cause.
+         logger.error("failed to parse sparse_result results from response:", e);
+         throw new IllegalArgumentException("failed to parse sparse_result results from response:", e);
+     }
+ }
+
+ /**
+  * Builds ranked documents from the response. The relevance-score path is required; the reranked-index
+  * and document-text paths are optional. Without an index path each score takes its position in the
+  * score array as index; without a document path the text is null. Any failure is logged and rethrown
+  * as an IllegalArgumentException.
+  */
+ private static InferenceServiceResults fromRerankResponse(HttpResult response, ResponseJsonParser responseJsonParser) {
+     String indexPath = responseJsonParser.getRerankedIndexPath();
+     String scorePath = responseJsonParser.getRelevanceScorePath();
+     String docPath = responseJsonParser.getDocumentTextPath();
+
+     var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+     try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+         Object jsonObject = jsonParser.map();
+         JSONArray indexResults = indexPath != null ? JsonPath.read(jsonObject, indexPath) : null;
+         if (indexPath != null && (indexResults == null || indexResults.isEmpty())) {
+             throw new IllegalArgumentException(
+                 String.format(Locale.ROOT, "can't parse reranked_index results from response, please check the path [%s]", indexPath)
+             );
+         }
+         JSONArray scoreResults = JsonPath.read(jsonObject, scorePath);
+         if (scoreResults == null || scoreResults.isEmpty()) {
+             throw new IllegalArgumentException(
+                 String.format(Locale.ROOT, "can't parse relevance_score results from response, please check the path [%s]", scorePath)
+             );
+         }
+         JSONArray docResults = docPath != null ? JsonPath.read(jsonObject, docPath) : null;
+         if (docPath != null && (docResults == null || docResults.isEmpty())) {
+             throw new IllegalArgumentException(
+                 String.format(Locale.ROOT, "can't parse doc results from response, please check the path [%s]", docPath)
+             );
+         }
+
+         // A JSON number is not guaranteed to be parsed as Integer (it can be a Long or a Double), so the
+         // indices are converted through Number, the same way the scores are, instead of being cast.
+         List<Integer> indices = indexResults != null
+             ? indexResults.stream().map(index -> ((Number) index).intValue()).toList()
+             : null;
+         List<Float> scores = scoreResults.stream().map(score -> ((Number) score).floatValue()).toList();
+         List<String> docs = docResults != null ? docResults.stream().map(doc -> (String) doc).toList() : null;
+         if (indices != null && indices.size() != scores.size()) {
+             throw new IllegalArgumentException("Indices and scores size does not match");
+         }
+         // Every score needs a document text when a document path is configured.
+         if (docs != null && docs.size() < scores.size()) {
+             throw new IllegalArgumentException("Documents and scores size does not match");
+         }
+         List<RankedDocsResults.RankedDoc> rerankResults = new ArrayList<>();
+         for (int i = 0; i < scores.size(); i++) {
+             int index = indices != null ? indices.get(i) : i;
+             String text = docs != null ? docs.get(i) : null;
+             rerankResults.add(new RankedDocsResults.RankedDoc(index, scores.get(i), text));
+         }
+         return new RankedDocsResults(rerankResults);
+     } catch (Exception e) {
+         logger.error("failed to parse rerank results from response:", e);
+         throw new IllegalArgumentException("failed to parse rerank results from response:", e);
+     }
+ }
+
+ /**
+  * Extracts completion text selected by the configured completion-result JSON path, which may yield a single
+  * string or a non-empty array of strings. Any failure is logged and rethrown as an IllegalArgumentException.
+  */
+ private static InferenceServiceResults fromCompletionResponse(HttpResult response, ResponseJsonParser responseJsonParser) {
+     String resultPath = responseJsonParser.getCompletionResultPath();
+     // Shared by the "nothing selected" and the "empty array" failures.
+     String template = "can't parse completion results from response, please check the path [%s]";
+     var config = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+     try (XContentParser bodyParser = XContentFactory.xContent(XContentType.JSON).createParser(config, response.body())) {
+         Object results = JsonPath.read(bodyParser.map(), resultPath);
+         if (results == null) {
+             throw new IllegalArgumentException(String.format(Locale.ROOT, template, resultPath));
+         }
+         if (results instanceof String) {
+             return new ChatCompletionResults(List.of(new ChatCompletionResults.Result((String) results)));
+         }
+         if (results instanceof JSONArray jsonArrayResults) {
+             if (jsonArrayResults.isEmpty()) {
+                 throw new IllegalArgumentException(String.format(Locale.ROOT, template, resultPath));
+             }
+             List<ChatCompletionResults.Result> completionResults = new ArrayList<>();
+             for (Object element : jsonArrayResults) {
+                 completionResults.add(new ChatCompletionResults.Result((String) element));
+             }
+             return new ChatCompletionResults(completionResults);
+         }
+         throw new IllegalArgumentException("Unsupported completion result type: " + results.getClass().getName());
+     } catch (Exception e) {
+         logger.error("failed to parse completion results from response:", e);
+         throw new IllegalArgumentException("failed to parse completion results from response:", e);
+     }
+ }
+
+ public static String unsupportedTaskTypeErrorMsg(TaskType taskType) { // message of the BAD_REQUEST raised by fromResponse
+ return "ai search custom service does not support task type [" + taskType + "]";
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
index 56bf6c1359a56..589bd3aceed6c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
@@ -74,7 +74,7 @@ public void infer(
private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) {
return switch (model.getTaskType()) {
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
- case RERANK -> new QueryAndDocsInputs(query, input, stream);
+ case RERANK, CUSTOM -> new QueryAndDocsInputs(query, input, stream);
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
default -> throw new ElasticsearchStatusException(
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 7330d45b6f16c..2abc58d435cfb 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -254,6 +254,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}
+ public static String mustBeNonEmptyMap(String settingName, String scope) { // error text for a map setting that is present but empty
+ return Strings.format("[%s] Invalid value empty map. [%s] must be a non-empty map", scope, settingName);
+ }
+
public static String invalidTimeValueMsg(String timeValueStr, String settingName, String scope, String exceptionMsg) {
return Strings.format(
"[%s] Invalid time value [%s]. [%s] must be a valid time value string: %s",
@@ -427,6 +431,58 @@ public static Integer extractRequiredPositiveInteger(
return field;
}
+ /**
+  * Reads {@code settingName} from {@code map} through {@code removeAsType} as a required, non-empty map.
+  * Returns null once a validation error was recorded: by removeAsType, for a missing setting, or for an empty map.
+  */
+ @SuppressWarnings("unchecked")
+ public static Map extractRequiredMap(
+     Map map,
+     String settingName,
+     String scope,
+     ValidationException validationException
+ ) {
+     int errorsBefore = validationException.validationErrors().size();
+     Map<String, Object> requiredField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException);
+     if (validationException.validationErrors().size() > errorsBefore) {
+         return null;
+     }
+     if (requiredField == null) {
+         validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(settingName, scope));
+         return null;
+     }
+     if (requiredField.isEmpty()) {
+         validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope));
+         return null;
+     }
+     return requiredField;
+ }
+
+ /**
+  * Reads {@code settingName} from {@code map} through {@code removeAsType} as an optional map.
+  * The setting may be absent, in which case null is returned without an error, but when it is present
+  * it must not be empty: an empty map is recorded as a validation error and null is returned.
+  * Null is also returned when removeAsType itself recorded a validation error.
+  */
+ @SuppressWarnings("unchecked")
+ public static Map extractOptionalMap(
+     Map map,
+     String settingName,
+     String scope,
+     ValidationException validationException
+ ) {
+     int errorsBefore = validationException.validationErrors().size();
+     Map<String, Object> optionalField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException);
+     if (validationException.validationErrors().size() > errorsBefore) {
+         return null;
+     }
+     if (optionalField != null && optionalField.isEmpty()) {
+         validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope));
+         return null;
+     }
+     return optionalField;
+ }
+
public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
Map map,
String settingName,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java
new file mode 100644
index 0000000000000..0d68b6ebbe077
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.custom.CustomActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Model definition for the custom inference service. It bundles the model configurations and secrets
+ * held by {@link Model} with the rate limit settings provider used for this model.
+ */
+public class CustomModel extends Model {
+ private final CustomRateLimitServiceSettings rateLimitServiceSettings;
+
+ /**
+ * Creates a copy of {@code model} whose task settings are the stored task settings combined with the
+ * task settings parsed from {@code taskSettings} (via {@code CustomTaskSettings.of}).
+ */
+ public static CustomModel of(CustomModel model, Map taskSettings) {
+ var overrides = CustomTaskSettings.fromMap(taskSettings);
+ var merged = CustomTaskSettings.of(model.getTaskSettings(), overrides);
+ return new CustomModel(model, merged);
+ }
+
+ /**
+ * @param rateLimitServiceSettings must not be {@code null}
+ */
+ public CustomModel(ModelConfigurations configurations, ModelSecrets secrets, CustomRateLimitServiceSettings rateLimitServiceSettings) {
+ super(configurations, secrets);
+ this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
+ }
+
+ /**
+ * Builds a model from raw settings maps by parsing each map into its typed settings object.
+ */
+ public CustomModel(
+ String modelId,
+ TaskType taskType,
+ String service,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secrets,
+ ConfigurationParseContext context
+ ) {
+ this(
+ modelId,
+ taskType,
+ service,
+ CustomServiceSettings.fromMap(serviceSettings, context, taskType),
+ CustomTaskSettings.fromMap(taskSettings),
+ CustomSecretSettings.fromMap(secrets)
+ );
+ }
+
+ // should only be used for testing (the map-based public constructor also delegates here)
+ CustomModel(
+ String modelId,
+ TaskType taskType,
+ String service,
+ CustomServiceSettings serviceSettings,
+ CustomTaskSettings taskSettings,
+ @Nullable CustomSecretSettings secretSettings
+ ) {
+ // The service settings double as the rate limit settings provider
+ this(
+ new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
+ new ModelSecrets(secretSettings),
+ serviceSettings
+ );
+ }
+
+ /**
+ * Copy constructor replacing the task settings; the rate limit settings provider is taken from {@code model}.
+ */
+ protected CustomModel(CustomModel model, TaskSettings taskSettings) {
+ super(model, taskSettings);
+ this.rateLimitServiceSettings = model.rateLimitServiceSettings();
+ }
+
+ /**
+ * Copy constructor replacing the service settings; the rate limit settings provider is taken from {@code model}.
+ */
+ protected CustomModel(CustomModel model, ServiceSettings serviceSettings) {
+ super(model, serviceSettings);
+ this.rateLimitServiceSettings = model.rateLimitServiceSettings();
+ }
+
+ /** Returns the rate limit settings provider supplied at construction. */
+ public CustomRateLimitServiceSettings rateLimitServiceSettings() {
+ return rateLimitServiceSettings;
+ }
+
+ /** Asks {@code visitor} to create the executable action for this model. */
+ public ExecutableAction accept(CustomActionVisitor visitor, Map taskSettings, InputType inputType) {
+ return visitor.create(this, taskSettings, inputType);
+ }
+
+ @Override
+ public CustomServiceSettings getServiceSettings() {
+ return (CustomServiceSettings) super.getServiceSettings();
+ }
+
+ @Override
+ public CustomTaskSettings getTaskSettings() {
+ return (CustomTaskSettings) super.getTaskSettings();
+ }
+
+ @Override
+ public CustomSecretSettings getSecretSettings() {
+ return (CustomSecretSettings) super.getSecretSettings();
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java
new file mode 100644
index 0000000000000..55641bad7ccaa
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+/**
+ * Implemented by settings objects of the custom inference service that can supply {@link RateLimitSettings}.
+ */
+public interface CustomRateLimitServiceSettings {
+ /** Returns the rate limit settings held by this settings object. */
+ RateLimitSettings rateLimitSettings();
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java
new file mode 100644
index 0000000000000..2c118869a854a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+
+/**
+ * Secret settings of the custom inference service: an optional map of secret parameters stored under
+ * {@value #SECRET_PARAMETERS}.
+ */
+public class CustomSecretSettings implements SecretSettings {
+ public static final String NAME = "custom_secret_settings";
+ public static final String SECRET_PARAMETERS = "secret_parameters";
+ private final Map secretParameters;
+
+ /**
+ * Parses the secret settings from {@code map}.
+ *
+ * @param map the map to read {@value #SECRET_PARAMETERS} from, may be {@code null}
+ * @return the parsed settings, or {@code null} when {@code map} is {@code null} or holds no secret parameters
+ * @throws ValidationException if {@value #SECRET_PARAMETERS} is present but invalid (for example an empty map)
+ */
+ public static CustomSecretSettings fromMap(@Nullable Map map) {
+ if (map == null) {
+ return null;
+ }
+
+ ValidationException validationException = new ValidationException();
+ Map requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException);
+
+ // Report invalid input instead of silently discarding the caller's secrets
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ if (requestSecretParamsMap == null) {
+ return null;
+ }
+
+ // Copy so later changes to the request map do not leak into the settings
+ return new CustomSecretSettings(new HashMap<>(requestSecretParamsMap));
+ }
+
+ @Override
+ public SecretSettings newSecretSettings(Map newSecrets) {
+ // fromMap consumes entries, so hand it a copy of the caller's map
+ return fromMap(new HashMap<>(newSecrets));
+ }
+
+ public CustomSecretSettings(@Nullable Map secretParameters) {
+ this.secretParameters = secretParameters;
+ }
+
+ /** Reads the settings from the wire: a presence flag followed by the generic map when present. */
+ public CustomSecretSettings(StreamInput in) throws IOException {
+ if (in.readBoolean()) {
+ secretParameters = in.readGenericMap();
+ } else {
+ secretParameters = null;
+ }
+ }
+
+ /** Returns the secret parameters, or {@code null} when none were configured. */
+ public Map getSecretParameters() {
+ return secretParameters;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ if (secretParameters != null) {
+ builder.field(SECRET_PARAMETERS, secretParameters);
+ }
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ // Same version as the custom service and its service settings, which introduce this writeable
+ return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
+ }
+
+ /** Writes a presence flag followed by the generic map when present; mirrors the stream constructor. */
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ if (secretParameters != null) {
+ out.writeBoolean(true);
+ out.writeGenericMap(secretParameters);
+ } else {
+ out.writeBoolean(false);
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ CustomSecretSettings that = (CustomSecretSettings) o;
+ return Objects.equals(secretParameters, that.secretParameters);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(secretParameters);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
new file mode 100644
index 0000000000000..a36f47af8e17b
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.util.LazyInitializable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SettingsConfiguration;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.inference.external.action.custom.CustomActionCreator;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.request.custom.CustomUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+
+import java.io.IOException;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
+
+/**
+ * Inference service that builds {@link CustomModel}s from user supplied settings and executes
+ * inference requests through actions created by {@link CustomActionCreator}.
+ */
+public class CustomService extends SenderService {
+ private static final Logger logger = LogManager.getLogger(CustomService.class);
+ public static final String NAME = CustomUtils.SERVICE_NAME;
+
+ private static final EnumSet supportedTaskTypes = EnumSet.of(
+ TaskType.TEXT_EMBEDDING,
+ TaskType.SPARSE_EMBEDDING,
+ TaskType.RERANK,
+ TaskType.COMPLETION,
+ TaskType.CUSTOM
+ );
+
+ public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
+ super(factory, serviceComponents);
+ }
+
+ @Override
+ public String name() {
+ return NAME;
+ }
+
+ /**
+ * Parses a create-endpoint request into a {@link CustomModel} and reports the outcome through
+ * {@code parsedModelListener}. Any exception raised while parsing or validating is passed to
+ * {@code onFailure} rather than thrown.
+ */
+ @Override
+ public void parseRequestConfig(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map config,
+ ActionListener parsedModelListener
+ ) {
+ try {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+ // The service settings map is intentionally passed twice: it is also the source of the secret settings
+ CustomModel model = createModel(
+ inferenceEntityId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ serviceSettingsMap,
+ TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
+ ConfigurationParseContext.REQUEST
+ );
+
+ // Anything left in the maps was not consumed while parsing and is rejected
+ throwIfNotEmptyMap(config, NAME);
+ throwIfNotEmptyMap(serviceSettingsMap, NAME);
+ throwIfNotEmptyMap(taskSettingsMap, NAME);
+
+ // Log only after validation so that rejected requests are not reported as added models
+ logModelConfig(model.getConfigurations());
+
+ parsedModelListener.onResponse(model);
+ } catch (Exception e) {
+ parsedModelListener.onFailure(e);
+ }
+ }
+
+ @Override
+ public InferenceServiceConfiguration getConfiguration() {
+ return Configuration.get();
+ }
+
+ @Override
+ public EnumSet supportedTaskTypes() {
+ return supportedTaskTypes;
+ }
+
+ /** Unified completion is not supported by this service; delegates to the shared "unsupported" helper. */
+ @Override
+ protected void doUnifiedCompletionInfer(
+ Model model,
+ UnifiedChatInput inputs,
+ TimeValue timeout,
+ ActionListener listener
+ ) {
+ throwUnsupportedUnifiedCompletionOperation(NAME);
+ }
+
+ /** Creates a model from persisted settings, i.e. with {@link ConfigurationParseContext#PERSISTENT}. */
+ private static CustomModel createModelWithoutLoggingDeprecations(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secretSettings,
+ String failureMessage
+ ) {
+ return createModel(
+ inferenceEntityId,
+ taskType,
+ serviceSettings,
+ taskSettings,
+ secretSettings,
+ failureMessage,
+ ConfigurationParseContext.PERSISTENT
+ );
+ }
+
+ /**
+ * Creates the model for this service. Note that {@code failureMessage} is currently not used:
+ * every task type is handed to the same {@link CustomModel} constructor.
+ */
+ private static CustomModel createModel(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secretSettings,
+ String failureMessage,
+ ConfigurationParseContext context
+ ) {
+ return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
+ }
+
+ /** Rebuilds a model from persisted configuration and secrets; all three settings sections are required. */
+ @Override
+ public CustomModel parsePersistedConfigWithSecrets(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map config,
+ Map secrets
+ ) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
+ Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
+
+ return createModelWithoutLoggingDeprecations(
+ inferenceEntityId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ secretSettingsMap,
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ );
+ }
+
+ /** Rebuilds a model from persisted configuration only; the resulting model has no secret settings. */
+ @Override
+ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
+
+ return createModelWithoutLoggingDeprecations(
+ inferenceEntityId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ null,
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ );
+ }
+
+ /**
+ * Executes an inference request. Models that are not {@link CustomModel}s are rejected through
+ * {@code listener.onFailure}.
+ */
+ @Override
+ public void doInfer(
+ Model model,
+ InferenceInputs inputs,
+ Map taskSettings,
+ InputType inputType,
+ TimeValue timeout,
+ ActionListener listener
+ ) {
+ if (model instanceof CustomModel == false) {
+ listener.onFailure(createInvalidModelException(model));
+ return;
+ }
+
+ CustomModel customModel = (CustomModel) model;
+
+ var actionCreator = new CustomActionCreator(getSender(), getServiceComponents());
+
+ var action = customModel.accept(actionCreator, taskSettings, inputType);
+ action.execute(inputs, timeout, listener);
+ }
+
+ /** Chunked inference is not supported; always fails the listener with a {@code BAD_REQUEST} status. */
+ @Override
+ protected void doChunkedInfer(
+ Model model,
+ DocumentsOnlyInput inputs,
+ Map taskSettings,
+ InputType inputType,
+ TimeValue timeout,
+ ActionListener<List<ChunkedInference>> listener
+ ) {
+ listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
+ }
+
+ /**
+ * Hands the model to the listener unchanged. No additional validation is performed and the
+ * service settings are not updated.
+ *
+ * @param model The new model
+ * @param listener The listener
+ */
+ @Override
+ public void checkModelConfig(Model model, ActionListener listener) {
+ listener.onResponse(model);
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
+ }
+
+ /** Lazily built service configuration; it lists the supported task types and an empty settings map. */
+ public static class Configuration {
+ public static InferenceServiceConfiguration get() {
+ return configuration.getOrCompute();
+ }
+
+ private static final LazyInitializable configuration = new LazyInitializable<>(
+ () -> {
+ var configurationMap = new HashMap();
+ return new InferenceServiceConfiguration.Builder().setService(NAME)
+ .setName(NAME)
+ .setTaskTypes(supportedTaskTypes)
+ .setConfigurations(configurationMap)
+ .build();
+ }
+ );
+ }
+
+ /**
+ * Logs the JSON form of the model configurations at info level.
+ *
+ * @throws IOException if the configurations cannot be serialized
+ */
+ private void logModelConfig(ModelConfigurations modelConfigurations) throws IOException {
+ // try-with-resources releases the builder even when serialization fails part way through
+ try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+ XContentBuilder modelBuilder = modelConfigurations.toXContent(builder, EMPTY_PARAMS);
+ String jsonString = BytesReference.bytes(modelBuilder).utf8ToString();
+ logger.info("add custom model: {}", jsonString);
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
new file mode 100644
index 0000000000000..3399307113bb1
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
@@ -0,0 +1,635 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpPut;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.custom.CustomUtils.SERVICE_NAME;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+
+public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {
+ // Writeable name of these settings (returned by getWriteableName)
+ public static final String NAME = "custom_service_settings";
+ // Top-level keys read from the service settings map by fromMap
+ public static final String DESCRIPTION = "description";
+ public static final String VERSION = "version";
+ public static final String URL = "url";
+ // Key of the nested object shaped as path -> { <endpoint path> -> { <http method> -> { ... } } }
+ public static final String PATH = "path";
+ // Keys read from the object nested under the http method
+ public static final String QUERY_STRING = "query_string";
+ public static final String HEADERS = "headers";
+ public static final String REQUEST = "request";
+ // Keys read from the "request" object
+ public static final String REQUEST_FORMAT = "format";
+ public static final String REQUEST_CONTENT = "content";
+ public static final String RESPONSE = "response";
+ // Key read from the "response" object
+ public static final String JSON_PARSER = "json_parser";
+
+ // NOTE(review): the parser keys below are not referenced by the code visible here; presumably they are
+ // consumed when building the ResponseJsonParser for each task type - confirm against extractResponseParser
+ public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings";
+
+ public static final String SPARSE_EMBEDDING_RESULT = "sparse_result";
+ // Same literal value as PATH above
+ public static final String SPARSE_RESULT_PATH = "path";
+ public static final String SPARSE_RESULT_VALUE = "value";
+ public static final String SPARSE_EMBEDDING_PARSER_TOKEN = "sparse_token";
+ public static final String SPARSE_EMBEDDING_PARSER_WEIGHT = "sparse_weight";
+
+ public static final String RERANK_PARSER_INDEX = "reranked_index";
+ public static final String RERANK_PARSER_SCORE = "relevance_score";
+ public static final String RERANK_PARSER_DOCUMENT_TEXT = "document_text";
+
+ public static final String COMPLETION_PARSER_RESULT = "completion_result";
+
+ // Values for the request "format" key; fromMap currently accepts only REQUEST_FORMAT_STRING
+ public static final String REQUEST_FORMAT_JSON = "json";
+ public static final String REQUEST_FORMAT_STRING = "string";
+
+ // Used when no rate limit settings are supplied
+ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
+
+ public static CustomServiceSettings fromMap(Map map, ConfigurationParseContext context, TaskType taskType) {
+ ValidationException validationException = new ValidationException();
+
+ SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
+ Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
+
+ String description = extractOptionalString(map, DESCRIPTION, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ String version = extractOptionalString(map, VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ String serviceType = taskType.toString();
+ String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+ Map pathMap = extractRequiredMap(map, PATH, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ if (pathMap == null) {
+ throw validationException;
+ }
+ if (pathMap.size() > 1) {
+ validationException.addValidationError("[" + PATH + "] only support one endpoint, but found [" + pathMap.keySet() + "]");
+ throw validationException;
+ }
+
+ String path = pathMap.keySet().iterator().next();
+ Map pathContent = extractRequiredMap(pathMap, path, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+ if (pathContent == null) {
+ throw validationException;
+ }
+ if (pathContent.size() > 1) {
+ validationException.addValidationError("[" + PATH + "] only support one method, but found [" + pathContent.keySet() + "]");
+ throw validationException;
+ }
+
+ String method = pathContent.keySet().iterator().next();
+ switch (method.toUpperCase()) {
+ case HttpGet.METHOD_NAME:
+ case HttpPut.METHOD_NAME:
+ case HttpPost.METHOD_NAME:
+ break;
+ default:
+ validationException.addValidationError(
+ String.format(
+ Locale.ROOT,
+ "unsupported http method [" + method + "], support [%s], [%s] and [%s]",
+ HttpGet.METHOD_NAME,
+ HttpPut.METHOD_NAME,
+ HttpPost.METHOD_NAME
+ )
+ );
+ }
+
+ Map modelParamsMap = extractRequiredMap(
+ pathContent,
+ method,
+ ModelConfigurations.SERVICE_SETTINGS,
+ validationException
+ );
+ if (modelParamsMap == null) {
+ throw validationException;
+ }
+
+ String queryString = extractOptionalString(modelParamsMap, QUERY_STRING, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+ Map headers = extractOptionalMap(
+ modelParamsMap,
+ HEADERS,
+ ModelConfigurations.SERVICE_SETTINGS,
+ validationException
+ );
+
+ Map requestBodyMap = extractRequiredMap(
+ modelParamsMap,
+ REQUEST,
+ ModelConfigurations.SERVICE_SETTINGS,
+ validationException
+ );
+
+ if (requestBodyMap == null) {
+ throw validationException;
+ }
+
+ String requestFormat = extractRequiredString(
+ requestBodyMap,
+ REQUEST_FORMAT,
+ ModelConfigurations.SERVICE_SETTINGS,
+ validationException
+ );
+ if (requestFormat == null) {
+ throw validationException;
+ }
+
+ Map requestContent = null;
+ String requestContentString = null;
+
+ switch (requestFormat) {
+ // todo support json format
+ case REQUEST_FORMAT_STRING:
+ requestContentString = extractRequiredString(
+ requestBodyMap,
+ REQUEST_CONTENT,
+ ModelConfigurations.SERVICE_SETTINGS,
+ validationException
+ );
+ break;
+ default:
+ validationException.addValidationError(
+ Strings.format(
+ "[%s.%s] does not support value [%s]. It should be [%s]",
+ REQUEST,
+ REQUEST_FORMAT,
+ requestFormat,
+ REQUEST_FORMAT_STRING
+ )
+ );
+ }
+
+ ResponseJsonParser responseJsonParser = null;
+ Map responseParserMap = null;
+ switch (taskType) {
+ case TEXT_EMBEDDING:
+ case SPARSE_EMBEDDING:
+ case RERANK:
+ case COMPLETION:
+ responseParserMap = extractRequiredMap(modelParamsMap, RESPONSE, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ if (responseParserMap == null) {
+ throw validationException;
+ }
+ Map jsonParserMap = extractRequiredMap(
+ responseParserMap,
+ JSON_PARSER,
+ ModelConfigurations.SERVICE_SETTINGS,
+ validationException
+ );
+ if (jsonParserMap == null) {
+ throw validationException;
+ }
+ responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException);
+ throwIfNotEmptyMap(responseParserMap, NAME);
+ }
+
+ RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+ map,
+ DEFAULT_RATE_LIMIT_SETTINGS,
+ validationException,
+ CustomService.NAME,
+ context
+ );
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new CustomServiceSettings(
+ similarity,
+ dims,
+ maxInputTokens,
+ description,
+ version,
+ serviceType,
+ url,
+ path,
+ method,
+ queryString,
+ headers,
+ requestFormat,
+ requestContent,
+ requestContentString,
+ responseJsonParser,
+ rateLimitSettings
+ );
+ }
+
+ private final SimilarityMeasure similarity;
+ private final Integer dimensions;
+ private final Integer maxInputTokens;
+ private final String description;
+ private final String version;
+ private final String url;
+ // String form of the task type the settings were parsed for (see fromMap)
+ private final String serviceType;
+ private final String path;
+ // Http method exactly as supplied by the user; fromMap validates it case-insensitively
+ private final String method;
+ private final String queryString;
+ private final Map headers;
+ private final String requestFormat;
+ private final Map requestContent;
+ private final String requestContentString;
+ private final ResponseJsonParser responseJsonParser;
+ private final RateLimitSettings rateLimitSettings;
+
+ /**
+ * Stores the given values as they are. A {@code null} {@code rateLimitSettings} is replaced by
+ * {@code DEFAULT_RATE_LIMIT_SETTINGS}; no other argument is validated or defaulted here.
+ */
+ public CustomServiceSettings(
+ @Nullable SimilarityMeasure similarity,
+ @Nullable Integer dimensions,
+ @Nullable Integer maxInputTokens,
+ String description,
+ String version,
+ String serviceType,
+ String url,
+ String path,
+ String method,
+ String queryString,
+ Map headers,
+ String requestFormat,
+ Map requestContent,
+ String requestContentString,
+ ResponseJsonParser responseJsonParser,
+ @Nullable RateLimitSettings rateLimitSettings
+ ) {
+ this.similarity = similarity;
+ this.dimensions = dimensions;
+ this.maxInputTokens = maxInputTokens;
+ this.description = description;
+ this.version = version;
+ this.serviceType = serviceType;
+ this.url = url;
+ this.path = path;
+ this.method = method;
+ this.queryString = queryString;
+ this.headers = headers;
+ this.requestFormat = requestFormat;
+ this.requestContent = requestContent;
+ this.requestContentString = requestContentString;
+ this.responseJsonParser = responseJsonParser;
+ this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+ }
+
+ /**
+ * Reads the settings from the wire. The read order mirrors {@code writeTo}; the headers, request content
+ * and response parser are each preceded by a boolean presence flag.
+ */
+ public CustomServiceSettings(StreamInput in) throws IOException {
+ this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
+ this.dimensions = in.readOptionalVInt();
+ this.maxInputTokens = in.readOptionalVInt();
+ this.description = in.readOptionalString();
+ this.version = in.readOptionalString();
+ this.serviceType = in.readString();
+ this.url = in.readString();
+ this.path = in.readString();
+ this.method = in.readString();
+ this.queryString = in.readOptionalString();
+ this.headers = in.readBoolean() ? in.readGenericMap() : null;
+ this.requestFormat = in.readString();
+ this.requestContent = in.readBoolean() ? in.readGenericMap() : null;
+ this.requestContentString = in.readOptionalString();
+ this.responseJsonParser = in.readBoolean() ? new ResponseJsonParser(in) : null;
+ this.rateLimitSettings = new RateLimitSettings(in);
+ }
+
+ /** Returns the configured similarity measure, or {@code null} when none was set. */
+ @Override
+ public SimilarityMeasure similarity() {
+ return similarity;
+ }
+
+ /** Returns the configured number of dimensions, or {@code null} when none was set. */
+ @Override
+ public Integer dimensions() {
+ return dimensions;
+ }
+
+ /** Always reports {@code FLOAT}; the element type is not configurable in these settings. */
+ @Override
+ public DenseVectorFieldMapper.ElementType elementType() {
+ return DenseVectorFieldMapper.ElementType.FLOAT;
+ }
+
+ /**
+ * Always returns {@code null}; the configured url is available as a string through {@code getUrl()}.
+ * NOTE(review): presumably a placeholder - confirm whether any caller expects a real URI here.
+ */
+ public URI getUri() {
+ return null;
+ }
+
+ // Plain accessors: each returns the stored field as is, so nullable fields may yield null
+
+ /** Same value as {@code similarity()}. */
+ public SimilarityMeasure getSimilarity() {
+ return similarity;
+ }
+
+ /** Same value as {@code dimensions()}. */
+ public Integer getDimensions() {
+ return dimensions;
+ }
+
+ public Integer getMaxInputTokens() {
+ return maxInputTokens;
+ }
+
+ public String getUrl() {
+ return url;
+ }
+
+ /** Returns the string form of the task type these settings were created for. */
+ public String getServiceType() {
+ return serviceType;
+ }
+
+ public String getPath() {
+ return path;
+ }
+
+ public String getMethod() {
+ return method;
+ }
+
+ public String getQueryString() {
+ return queryString;
+ }
+
+ /** Returns the stored headers map itself (not a copy). */
+ public Map getHeaders() {
+ return headers;
+ }
+
+ public String getRequestFormat() {
+ return requestFormat;
+ }
+
+ /** Returns the stored request content map itself (not a copy). */
+ public Map getRequestContent() {
+ return requestContent;
+ }
+
+ public String getRequestContentString() {
+ return requestContentString;
+ }
+
+ public ResponseJsonParser getResponseJsonParser() {
+ return responseJsonParser;
+ }
+
+ /** Never {@code null} for instances built through the canonical constructor, which applies a default. */
+ @Override
+ public RateLimitSettings rateLimitSettings() {
+ return rateLimitSettings;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ /** Writes the settings as a complete object: the fragment wrapped in start/end object. */
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ toXContentFragment(builder, params);
+ return builder.endObject();
+ }
+
+ /** Writes the settings fields without the surrounding object; identical to the exposed-fields fragment. */
+ public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
+ return toXContentFragmentOfExposedFields(builder, params);
+ }
+
+ /**
+ * Writes the settings fields into an already started object. Optional top-level fields are written only
+ * when non-null, followed by the nested {@code path -> <path> -> <method>} object (query string, headers,
+ * request and, when present, response) and finally the rate limit settings. The nesting matches the
+ * structure that {@code fromMap} reads.
+ */
+ @Override
+ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+ if (similarity != null) {
+ builder.field(SIMILARITY, similarity);
+ }
+ if (dimensions != null) {
+ builder.field(DIMENSIONS, dimensions);
+ }
+ if (maxInputTokens != null) {
+ builder.field(MAX_INPUT_TOKENS, maxInputTokens);
+ }
+ if (description != null) {
+ builder.field(DESCRIPTION, description);
+ }
+ if (version != null) {
+ builder.field(VERSION, version);
+ }
+ if (url != null) {
+ builder.field(URL, url);
+ }
+ // The bare braces below only mirror the nesting of the emitted objects
+ builder.startObject(PATH);
+ {
+ // The endpoint path and the http method are used as object keys
+ builder.startObject(path);
+ {
+ builder.startObject(method);
+ {
+ if (queryString != null) {
+ builder.field(QUERY_STRING, queryString);
+ }
+ if (headers != null) {
+ builder.field(HEADERS, headers);
+ }
+
+ builder.startObject(REQUEST);
+ {
+ if (requestFormat != null) {
+ builder.field(REQUEST_FORMAT, requestFormat);
+ }
+ // The map form of the content takes precedence over the string form
+ if (requestContent != null) {
+ builder.field(REQUEST_CONTENT, requestContent);
+ } else if (requestContentString != null) {
+ builder.field(REQUEST_CONTENT, requestContentString);
+ }
+ }
+ builder.endObject();
+
+ if (responseJsonParser != null) {
+ builder.startObject(RESPONSE);
+ {
+ responseJsonParser.toXContent(builder, params);
+ }
+ builder.endObject();
+ }
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+ }
+ builder.endObject();
+
+ rateLimitSettings.toXContent(builder, params);
+
+ return builder;
+ }
+
+ /**
+ * Returns this instance as the filtered representation, so the filtered output is whatever
+ * {@code toXContent} renders.
+ * NOTE(review): headers and request content are rendered as stored — confirm neither can carry secrets.
+ */
+ @Override
+ public ToXContentObject getFilteredXContentObject() {
+ return this;
+ }
+
+ /**
+ * Returns {@code TransportVersions.ADD_INFERENCE_CUSTOM_MODEL}, the transport version added
+ * together with the custom inference model, as the minimal version supporting these settings.
+ */
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
+ }
+
+    /**
+     * Serializes these settings to the transport stream.
+     * <p>
+     * The statement order below is the wire layout. The two maps and the response parser are each
+     * preceded by a boolean presence flag.
+     * NOTE(review): {@code url} and {@code requestFormat} are written as mandatory strings although
+     * {@code toXContentFragmentOfExposedFields} null-checks them — confirm they can never be null here.
+     *
+     * @param out the stream to write to
+     * @throws IOException if writing to the stream fails
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalEnum(similarity);
+        out.writeOptionalVInt(dimensions);
+        out.writeOptionalVInt(maxInputTokens);
+        out.writeOptionalString(description);
+        out.writeOptionalString(version);
+        out.writeString(serviceType);
+        out.writeString(url);
+        out.writeString(path);
+        out.writeString(method);
+        out.writeOptionalString(queryString);
+
+        final boolean hasHeaders = headers != null;
+        out.writeBoolean(hasHeaders);
+        if (hasHeaders) {
+            out.writeGenericMap(headers);
+        }
+
+        out.writeString(requestFormat);
+
+        final boolean hasRequestContent = requestContent != null;
+        out.writeBoolean(hasRequestContent);
+        if (hasRequestContent) {
+            out.writeGenericMap(requestContent);
+        }
+        out.writeOptionalString(requestContentString);
+
+        final boolean hasResponseParser = responseJsonParser != null;
+        out.writeBoolean(hasResponseParser);
+        if (hasResponseParser) {
+            responseJsonParser.writeTo(out);
+        }
+
+        rateLimitSettings.writeTo(out);
+    }
+
+    /**
+     * Two instances are equal when they are of the same class and every settings field matches.
+     *
+     * @param o the object to compare with
+     * @return {@code true} if {@code o} is an equal {@code CustomServiceSettings}
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        final CustomServiceSettings other = (CustomServiceSettings) o;
+        return Objects.equals(similarity, other.similarity)
+            && Objects.equals(dimensions, other.dimensions)
+            && Objects.equals(maxInputTokens, other.maxInputTokens)
+            && Objects.equals(description, other.description)
+            && Objects.equals(version, other.version)
+            && Objects.equals(serviceType, other.serviceType)
+            && Objects.equals(url, other.url)
+            && Objects.equals(path, other.path)
+            && Objects.equals(method, other.method)
+            && Objects.equals(queryString, other.queryString)
+            && Objects.equals(headers, other.headers)
+            && Objects.equals(requestFormat, other.requestFormat)
+            && Objects.equals(requestContent, other.requestContent)
+            && Objects.equals(requestContentString, other.requestContentString)
+            && Objects.equals(responseJsonParser, other.responseJsonParser)
+            && Objects.equals(rateLimitSettings, other.rateLimitSettings);
+    }
+
+    /**
+     * Hashes the same fields, in the same order, that {@link #equals(Object)} compares.
+     *
+     * @return the combined hash of all settings fields
+     */
+    @Override
+    public int hashCode() {
+        // Grouped as: model metadata, endpoint, request, response/rate limiting.
+        return Objects.hash(
+            similarity, dimensions, maxInputTokens, description, version,
+            serviceType, url, path, method, queryString, headers,
+            requestFormat, requestContent, requestContentString,
+            responseJsonParser, rateLimitSettings
+        );
+    }
+
+ /**
+ * Returns the {@code SERVICE_NAME} constant as the model id.
+ * NOTE(review): this is a fixed constant rather than a per-endpoint value — confirm that is intended.
+ */
+ @Override
+ public String modelId() {
+ return SERVICE_NAME;
+ }
+
+    /**
+     * Builds the response parser matching the given task type.
+     *
+     * @param taskType            the task type the parser is for
+     * @param responseParserMap   the parser configuration map
+     * @param validationException collector handed on to the parser for validation errors
+     * @return the parser for {@code taskType}
+     * @throws IllegalArgumentException if {@code taskType} is not text embedding, sparse embedding,
+     *                                  rerank or completion
+     */
+    private static ResponseJsonParser extractResponseParser(
+        TaskType taskType,
+        Map responseParserMap,
+        ValidationException validationException
+    ) {
+        switch (taskType) {
+            case TEXT_EMBEDDING:
+                return extractTextEmbeddingResponseParser(responseParserMap, validationException);
+            case SPARSE_EMBEDDING:
+                return extractSparseTextEmbeddingResponseParser(responseParserMap, validationException);
+            case RERANK:
+                return extractRerankResponseParser(responseParserMap, validationException);
+            case COMPLETION:
+                return extractCompletionResponseParser(responseParserMap, validationException);
+            default:
+                throw new IllegalArgumentException(
+                    String.format(Locale.ROOT, "response json parser does not support TaskType [%s]", taskType)
+                );
+        }
+    }
+
+ private static ResponseJsonParser extractTextEmbeddingResponseParser(
+ Map responseParserMap,
+ ValidationException validationException
+ ) {
+ return new ResponseJsonParser(TaskType.TEXT_EMBEDDING, responseParserMap, validationException);
+ }
+
+ private static ResponseJsonParser extractSparseTextEmbeddingResponseParser(
+ Map responseParserMap,
+ ValidationException validationException
+ ) {
+ return new ResponseJsonParser(TaskType.SPARSE_EMBEDDING, responseParserMap, validationException);
+ }
+
+ private static ResponseJsonParser extractRerankResponseParser(
+ Map responseParserMap,
+ ValidationException validationException
+ ) {
+ return new ResponseJsonParser(TaskType.RERANK, responseParserMap, validationException);
+ }
+
+ private static ResponseJsonParser extractCompletionResponseParser(
+ Map responseParserMap,
+ ValidationException validationException
+ ) {
+ return new ResponseJsonParser(TaskType.COMPLETION, responseParserMap, validationException);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java
new file mode 100644
index 0000000000000..9fed14a2d9a9f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java
@@ -0,0 +1,164 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+
+/**
+ * Task settings for the custom inference service: a free-form {@code parameters} map plus an
+ * optional {@code ignore_placeholder_check} flag. Both values may be {@code null} when unset.
+ */
+public class CustomTaskSettings implements TaskSettings {
+    public static final String NAME = "custom_task_settings";
+
+    public static final String PARAMETERS = "parameters";
+    public static final String IGNORE_PLACEHOLDER_CHECK = "ignore_placeholder_check";
+
+    static final CustomTaskSettings EMPTY_SETTINGS = new CustomTaskSettings(null, null);
+
+    /**
+     * Parses task settings from a configuration map.
+     *
+     * @param map the raw task settings; {@code null} or empty yields {@link #EMPTY_SETTINGS}
+     * @return the parsed settings
+     * @throws ValidationException if extracting either field recorded a validation error
+     */
+    public static CustomTaskSettings fromMap(Map map) {
+        if (map == null || map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+        Map parameters = extractOptionalMap(map, PARAMETERS, ModelConfigurations.TASK_SETTINGS, validationException);
+        Boolean ignorePlaceholderCheck = extractOptionalBoolean(map, IGNORE_PLACEHOLDER_CHECK, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new CustomTaskSettings(parameters, ignorePlaceholderCheck);
+    }
+
+    /**
+     * Creates a new {@link CustomTaskSettings}
+     * by preferring non-null fields from the request settings over the original settings.
+     * When both sides define parameters the two maps are merged, with request entries replacing
+     * original entries that share a key. A {@code null} argument is treated as "no settings":
+     * the other argument is returned, or {@link #EMPTY_SETTINGS} when both are {@code null}.
+     *
+     * @param originalSettings the settings stored as part of the inference entity configuration
+     * @param requestTaskSettings the settings passed in within the task_settings field of the request
+     * @return a constructed {@link CustomTaskSettings}
+     */
+    public static CustomTaskSettings of(CustomTaskSettings originalSettings, CustomTaskSettings requestTaskSettings) {
+        if (requestTaskSettings == null) {
+            return originalSettings == null ? EMPTY_SETTINGS : originalSettings;
+        }
+        if (originalSettings == null) {
+            return requestTaskSettings;
+        }
+
+        Boolean ignorePlaceholderCheck = requestTaskSettings.ignorePlaceholderCheck != null
+            ? requestTaskSettings.ignorePlaceholderCheck
+            : originalSettings.ignorePlaceholderCheck;
+
+        if (originalSettings.parameters != null && requestTaskSettings.parameters != null) {
+            // putAll rather than merge: HashMap.merge rejects null values, which a request may legitimately carry.
+            var merged = new HashMap<>(originalSettings.parameters);
+            merged.putAll(requestTaskSettings.parameters);
+            return new CustomTaskSettings(merged, ignorePlaceholderCheck);
+        }
+
+        return new CustomTaskSettings(
+            requestTaskSettings.parameters != null ? requestTaskSettings.parameters : originalSettings.parameters,
+            ignorePlaceholderCheck
+        );
+    }
+
+    private final Map parameters;
+    private final Boolean ignorePlaceholderCheck;
+
+    /**
+     * Reads settings from the transport stream, mirroring {@link #writeTo(StreamOutput)}:
+     * a presence flag followed by the parameters map, then the optional boolean.
+     *
+     * @param in the stream to read from
+     * @throws IOException if reading from the stream fails
+     */
+    public CustomTaskSettings(StreamInput in) throws IOException {
+        parameters = in.readBoolean() ? in.readGenericMap() : null;
+        ignorePlaceholderCheck = in.readOptionalBoolean();
+    }
+
+    /**
+     * @param parameters             the parameters map, stored as given (not copied); may be {@code null}
+     * @param ignorePlaceholderCheck the placeholder-check flag; may be {@code null}
+     */
+    public CustomTaskSettings(Map parameters, Boolean ignorePlaceholderCheck) {
+        this.parameters = parameters;
+        this.ignorePlaceholderCheck = ignorePlaceholderCheck;
+    }
+
+    /**
+     * @return the parameters map, or {@code null} if none was configured
+     */
+    public Map getParameters() {
+        return parameters;
+    }
+
+    /**
+     * @return the placeholder-check flag, or {@code null} if it was not configured
+     */
+    public Boolean getIgnorePlaceholderCheck() {
+        return ignorePlaceholderCheck;
+    }
+
+    /**
+     * Renders the settings as an object, omitting fields that are {@code null}.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (parameters != null) {
+            builder.field(PARAMETERS, parameters);
+        }
+        if (ignorePlaceholderCheck != null) {
+            builder.field(IGNORE_PLACEHOLDER_CHECK, ignorePlaceholderCheck);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    /**
+     * Returns the parameters map exactly as {@link #getParameters()} does.
+     * NOTE(review): despite the name no default is substituted, so this can return {@code null} —
+     * confirm whether callers expect an empty map instead.
+     */
+    public Map getParametersOrDefault() {
+        return parameters;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    /**
+     * These settings are introduced together with the custom inference model, so they share the
+     * transport version that the custom service settings report.
+     */
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL;
+    }
+
+    /**
+     * Writes a presence flag followed by the parameters map (when present), then the optional boolean.
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        if (parameters == null) {
+            out.writeBoolean(false);
+        } else {
+            out.writeBoolean(true);
+            out.writeGenericMap(parameters);
+        }
+        out.writeOptionalBoolean(ignorePlaceholderCheck);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        CustomTaskSettings that = (CustomTaskSettings) o;
+        return Objects.equals(parameters, that.parameters) && Objects.equals(ignorePlaceholderCheck, that.ignorePlaceholderCheck);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(parameters, ignorePlaceholderCheck);
+    }
+
+    /**
+     * @return {@code true} when there are no parameters (null or empty map) and the flag is unset
+     */
+    @Override
+    public boolean isEmpty() {
+        return (parameters == null || parameters.isEmpty()) && ignorePlaceholderCheck == null;
+    }
+
+    /**
+     * Parses {@code newSettings} (from a copy, so the caller's map is left untouched by parsing)
+     * and overlays the result on these settings via {@link #of}.
+     */
+    @Override
+    public TaskSettings updatedTaskSettings(Map newSettings) {
+        CustomTaskSettings updatedSettings = CustomTaskSettings.fromMap(new HashMap<>(newSettings));
+        return of(this, updatedSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/ResponseJsonParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/ResponseJsonParser.java
new file mode 100644
index 0000000000000..6c2c1bfda882d
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/ResponseJsonParser.java
@@ -0,0 +1,274 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.COMPLETION_PARSER_RESULT;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.RERANK_PARSER_DOCUMENT_TEXT;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.RERANK_PARSER_INDEX;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.RERANK_PARSER_SCORE;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.SPARSE_EMBEDDING_PARSER_TOKEN;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.SPARSE_EMBEDDING_PARSER_WEIGHT;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.SPARSE_EMBEDDING_RESULT;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.SPARSE_RESULT_PATH;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.SPARSE_RESULT_VALUE;
+import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS;
+
+/**
+ * Holds the paths used to pull results out of a custom service's JSON response.
+ * Which paths are populated depends on the task type: text embedding, sparse embedding, rerank
+ * or completion. Paths belonging to other task types stay {@code null}.
+ */
+public class ResponseJsonParser {
+    public static final String FIELD_NAME = JSON_PARSER;
+
+    // Stored as the string form of the TaskType; parsed back with TaskType.fromString where needed.
+    private final String taskTypeStr;
+
+    private String textEmbeddingsPath;
+
+    private String sparseResultPath;
+    private String sparseTokenPath;
+    private String sparseWeightPath;
+
+    private String rerankedIndexPath;
+    private String relevanceScorePath;
+    private String documentTextPath;
+
+    private String completionResultPath;
+
+    /**
+     * Builds a parser from a configuration map, extracting the paths required by {@code taskType}.
+     *
+     * @param taskType            the task type that decides which entries are read
+     * @param responseParserMap   the parser configuration map
+     * @param validationException collector for extraction errors; it is thrown directly when a
+     *                            required nested map of the sparse embedding configuration is missing
+     * @throws IllegalArgumentException if {@code taskType} is not one of the four supported types
+     */
+    public ResponseJsonParser(TaskType taskType, Map responseParserMap, ValidationException validationException) {
+        this.taskTypeStr = taskType.toString();
+        switch (taskType) {
+            case TEXT_EMBEDDING -> textEmbeddingsPath = extractRequiredString(
+                responseParserMap,
+                TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+                JSON_PARSER,
+                validationException
+            );
+            case SPARSE_EMBEDDING -> {
+                Map sparseResultMap = extractRequiredMap(
+                    responseParserMap,
+                    SPARSE_EMBEDDING_RESULT,
+                    JSON_PARSER,
+                    validationException
+                );
+                // Nothing further can be extracted without the nested map, so fail now.
+                if (sparseResultMap == null) {
+                    throw validationException;
+                }
+                sparseResultPath = extractRequiredString(sparseResultMap, SPARSE_RESULT_PATH, JSON_PARSER, validationException);
+                Map sparseResultValueMap = extractRequiredMap(
+                    sparseResultMap,
+                    SPARSE_RESULT_VALUE,
+                    JSON_PARSER,
+                    validationException
+                );
+                if (sparseResultValueMap == null) {
+                    throw validationException;
+                }
+                sparseTokenPath = extractRequiredString(
+                    sparseResultValueMap,
+                    SPARSE_EMBEDDING_PARSER_TOKEN,
+                    JSON_PARSER,
+                    validationException
+                );
+                sparseWeightPath = extractRequiredString(
+                    sparseResultValueMap,
+                    SPARSE_EMBEDDING_PARSER_WEIGHT,
+                    JSON_PARSER,
+                    validationException
+                );
+            }
+            case RERANK -> {
+                // Only the score path is mandatory for rerank; index and document text are optional.
+                rerankedIndexPath = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException);
+
+                relevanceScorePath = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException);
+
+                documentTextPath = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException);
+            }
+            case COMPLETION -> completionResultPath = extractRequiredString(
+                responseParserMap,
+                COMPLETION_PARSER_RESULT,
+                JSON_PARSER,
+                validationException
+            );
+            default -> throw new IllegalArgumentException(
+                String.format(Locale.ROOT, "json parser does not support taskType [%s]", taskType)
+            );
+        }
+    }
+
+    /**
+     * Reads a parser from the transport stream; the layout mirrors {@link #writeTo(StreamOutput)}.
+     *
+     * @param in the stream to read from
+     * @throws IOException              if reading from the stream fails
+     * @throws IllegalArgumentException if the task type read from the stream is not supported,
+     *                                  matching the behaviour of the map-based constructor
+     */
+    public ResponseJsonParser(StreamInput in) throws IOException {
+        this.taskTypeStr = in.readString();
+        TaskType taskType = TaskType.fromString(this.taskTypeStr);
+        switch (taskType) {
+            case TEXT_EMBEDDING -> this.textEmbeddingsPath = in.readString();
+            case SPARSE_EMBEDDING -> {
+                this.sparseResultPath = in.readString();
+                this.sparseTokenPath = in.readString();
+                this.sparseWeightPath = in.readString();
+            }
+            case RERANK -> {
+                this.rerankedIndexPath = in.readOptionalString();
+                this.relevanceScorePath = in.readString();
+                this.documentTextPath = in.readOptionalString();
+            }
+            case COMPLETION -> this.completionResultPath = in.readString();
+            // Fail fast instead of silently producing a parser that has no paths at all.
+            default -> throw new IllegalArgumentException(
+                String.format(Locale.ROOT, "json parser does not support taskType [%s]", taskType)
+            );
+        }
+    }
+
+    /**
+     * Writes the task type followed by the paths belonging to that task type.
+     * No default branch is needed: both constructors reject unsupported task types.
+     *
+     * @param out the stream to write to
+     * @throws IOException if writing to the stream fails
+     */
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(this.taskTypeStr);
+        TaskType taskType = TaskType.fromString(this.taskTypeStr);
+        switch (taskType) {
+            case TEXT_EMBEDDING -> out.writeString(this.textEmbeddingsPath);
+            case SPARSE_EMBEDDING -> {
+                out.writeString(this.sparseResultPath);
+                out.writeString(this.sparseTokenPath);
+                out.writeString(this.sparseWeightPath);
+            }
+            case RERANK -> {
+                out.writeOptionalString(this.rerankedIndexPath);
+                out.writeString(this.relevanceScorePath);
+                out.writeOptionalString(this.documentTextPath);
+            }
+            case COMPLETION -> out.writeString(this.completionResultPath);
+        }
+    }
+
+    /** @return the text embedding path; {@code null} unless the task type is text embedding */
+    public String getTextEmbeddingsPath() {
+        return textEmbeddingsPath;
+    }
+
+    /** @return the sparse result path; {@code null} unless the task type is sparse embedding */
+    public String getSparseResultPath() {
+        return sparseResultPath;
+    }
+
+    /** @return the sparse token path; {@code null} unless the task type is sparse embedding */
+    public String getSparseTokenPath() {
+        return sparseTokenPath;
+    }
+
+    /** @return the sparse weight path; {@code null} unless the task type is sparse embedding */
+    public String getSparseWeightPath() {
+        return sparseWeightPath;
+    }
+
+    /** @return the optional rerank index path; may be {@code null} */
+    public String getRerankedIndexPath() {
+        return rerankedIndexPath;
+    }
+
+    /** @return the rerank score path; {@code null} unless the task type is rerank */
+    public String getRelevanceScorePath() {
+        return relevanceScorePath;
+    }
+
+    /** @return the optional rerank document text path; may be {@code null} */
+    public String getDocumentTextPath() {
+        return documentTextPath;
+    }
+
+    /** @return the completion result path; {@code null} unless the task type is completion */
+    public String getCompletionResultPath() {
+        return completionResultPath;
+    }
+
+    /**
+     * Renders the parser as a {@code FIELD_NAME} object containing the entries of its task type.
+     * Optional rerank paths are omitted when {@code null}.
+     *
+     * @param builder the builder to write into
+     * @param params  rendering parameters (not used here)
+     * @return the same builder, for chaining
+     * @throws IOException if the builder fails to write
+     */
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject(FIELD_NAME);
+        {
+            switch (TaskType.fromString(this.taskTypeStr)) {
+                case TEXT_EMBEDDING -> builder.field(TEXT_EMBEDDING_PARSER_EMBEDDINGS, textEmbeddingsPath);
+                case SPARSE_EMBEDDING -> {
+                    builder.startObject(SPARSE_EMBEDDING_RESULT);
+                    {
+                        builder.field(SPARSE_RESULT_PATH, sparseResultPath);
+                        builder.startObject(SPARSE_RESULT_VALUE);
+                        {
+                            builder.field(SPARSE_EMBEDDING_PARSER_TOKEN, sparseTokenPath);
+                            builder.field(SPARSE_EMBEDDING_PARSER_WEIGHT, sparseWeightPath);
+                        }
+                        builder.endObject();
+                    }
+                    builder.endObject();
+                }
+                case RERANK -> {
+                    if (rerankedIndexPath != null) {
+                        builder.field(RERANK_PARSER_INDEX, rerankedIndexPath);
+                    }
+                    builder.field(RERANK_PARSER_SCORE, relevanceScorePath);
+                    if (documentTextPath != null) {
+                        builder.field(RERANK_PARSER_DOCUMENT_TEXT, documentTextPath);
+                    }
+                }
+                case COMPLETION -> builder.field(COMPLETION_PARSER_RESULT, completionResultPath);
+            }
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    /**
+     * Builds a parser from the {@code FIELD_NAME} entry of {@code map}; the task type is read from
+     * inside that entry.
+     *
+     * @param map                 the map holding the parser configuration under {@code FIELD_NAME}
+     * @param validationException collector for extraction errors; thrown directly when the parser
+     *                            map or its task type entry is missing
+     * @param serviceName         the service name passed to the leftover-entries check
+     * @param context             the parse context; leftover entries are only checked in a request context
+     * @return the constructed parser
+     */
+    public static ResponseJsonParser of(
+        Map map,
+        ValidationException validationException,
+        String serviceName,
+        ConfigurationParseContext context
+    ) {
+        Map responseParserMap = extractRequiredMap(map, FIELD_NAME, JSON_PARSER, validationException);
+        if (responseParserMap == null) {
+            throw validationException;
+        }
+        String taskTypeStr = extractRequiredString(responseParserMap, TaskType.NAME, FIELD_NAME, validationException);
+        if (taskTypeStr == null) {
+            throw validationException;
+        }
+        TaskType taskType = TaskType.fromString(taskTypeStr);
+        ResponseJsonParser responseJsonParser = new ResponseJsonParser(taskType, responseParserMap, validationException);
+
+        if (ConfigurationParseContext.isRequestContext(context)) {
+            throwIfNotEmptyMap(responseParserMap, serviceName);
+        }
+        return responseJsonParser;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ResponseJsonParser that = (ResponseJsonParser) o;
+        return Objects.equals(taskTypeStr, that.taskTypeStr)
+            && Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath)
+            && Objects.equals(sparseResultPath, that.sparseResultPath)
+            && Objects.equals(sparseTokenPath, that.sparseTokenPath)
+            && Objects.equals(sparseWeightPath, that.sparseWeightPath)
+            && Objects.equals(rerankedIndexPath, that.rerankedIndexPath)
+            && Objects.equals(relevanceScorePath, that.relevanceScorePath)
+            && Objects.equals(documentTextPath, that.documentTextPath)
+            && Objects.equals(completionResultPath, that.completionResultPath);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(
+            taskTypeStr,
+            textEmbeddingsPath,
+            sparseResultPath,
+            sparseTokenPath,
+            sparseWeightPath,
+            rerankedIndexPath,
+            relevanceScorePath,
+            documentTextPath,
+            completionResultPath
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestTests.java
new file mode 100644
index 0000000000000..470f5f4850ffc
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestTests.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.custom;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.custom.CustomModel;
+import org.elasticsearch.xpack.inference.services.custom.CustomModelTests;
+import org.hamcrest.MatcherAssert;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.List;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class CustomRequestTests extends ESTestCase {
+ public void testCreateRequest() throws IOException {
+ // create request
+ var request = createRequest(null, List.of("abc"), CustomModelTests.getTestModel());
+ var httpRequest = request.createHttpRequest();
+ assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+ String queryStringRes = "?query=" + CustomModelTests.taskSettingsValue;
+ var httpPost = (HttpPost) httpRequest.httpRequestBase();
+ var uri = httpPost.getURI().toString();
+ MatcherAssert.assertThat(uri, is(CustomModelTests.url + CustomModelTests.path + queryStringRes));
+ MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+ MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(CustomModelTests.secretSettingsValue));
+
+ String requestBody = convertStreamToString(httpPost.getEntity().getContent());
+ MatcherAssert.assertThat(requestBody, is("\"input\":\"[\"abc\"]\""));
+ }
+
+ public static CustomRequest createRequest(String query, List input, CustomModel model) {
+ return new CustomRequest(query, input, model);
+ }
+
+ private static String convertStreamToString(InputStream inputStream) {
+ StringBuilder stringBuilder = new StringBuilder();
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ stringBuilder.append(line);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return stringBuilder.toString();
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntityTests.java
new file mode 100644
index 0000000000000..f73583d2a556d
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntityTests.java
@@ -0,0 +1,229 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.custom;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.custom.CustomRequestTests;
+import org.elasticsearch.xpack.inference.services.custom.CustomModelTests;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+// Tests CustomResponseEntity.fromResponse for each supported task type: a canned JSON response is
+// parsed using the parser paths configured on the test model, and the resulting type and values are checked.
+public class CustomResponseEntityTests extends ESTestCase {
+
+ // Text embedding: uses the default test model; expects one float embedding with the two values from the response.
+ public void testFromTextEmbeddingResponse() throws IOException {
+ String responseJson = """
+ {
+ "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
+ "latency": 38,
+ "usage": {
+ "token_count": 3072
+ },
+ "result": {
+ "embeddings": [
+ {
+ "index": 0,
+ "embedding": [
+ -0.02868066355586052,
+ 0.022033605724573135
+ ]
+ }
+ ]
+ }
+ }
+ """;
+
+ var request = CustomRequestTests.createRequest(null, List.of("abc"), CustomModelTests.getTestModel());
+ InferenceServiceResults results = CustomResponseEntity.fromResponse(
+ request,
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
+ assertThat(
+ ((TextEmbeddingFloatResults) results).embeddings(),
+ is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
+ );
+ }
+
+ // Sparse embedding: the parser map points at the embedding list and at the tokenId/weight of each entry.
+ public void testFromSparseEmbeddingResponse() throws IOException {
+ String responseJson = """
+ {
+ "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
+ "latency": 22,
+ "usage": {
+ "token_count": 11
+ },
+ "result": {
+ "sparse_embeddings": [
+ {
+ "index": 0,
+ "embedding": [
+ {
+ "tokenId": 6,
+ "weight": 0.10137939453125
+ },
+ {
+ "tokenId": 163040,
+ "weight": 0.2841796875
+ }
+ ]
+ }
+ ]
+ }
+ }
+ """;
+
+ // Nested parser configuration: result -> { path, value -> { token, weight } }.
+ Map jsonParserMap = new HashMap<>(
+ Map.of(
+ CustomServiceSettings.SPARSE_EMBEDDING_RESULT,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.SPARSE_RESULT_PATH,
+ "$.result.sparse_embeddings[*]",
+ CustomServiceSettings.SPARSE_RESULT_VALUE,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.SPARSE_EMBEDDING_PARSER_TOKEN,
+ "$.embedding[*].tokenId",
+ CustomServiceSettings.SPARSE_EMBEDDING_PARSER_WEIGHT,
+ "$.embedding[*].weight"
+ )
+ )
+ )
+ )
+ )
+ );
+
+ var request = CustomRequestTests.createRequest(
+ null,
+ List.of("abc"),
+ CustomModelTests.getTestModel(TaskType.SPARSE_EMBEDDING, jsonParserMap)
+ );
+
+ InferenceServiceResults results = CustomResponseEntity.fromResponse(
+ request,
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+ assertThat(results, instanceOf(SparseEmbeddingResults.class));
+
+ SparseEmbeddingResults sparseEmbeddingResults = (SparseEmbeddingResults) results;
+
+ // Expected: a single embedding holding both weighted tokens, with token ids rendered as strings.
+ List embeddingList = new ArrayList<>();
+ List weightedTokens = new ArrayList<>();
+ weightedTokens.add(new WeightedToken("6", 0.10137939453125f));
+ weightedTokens.add(new WeightedToken("163040", 0.2841796875f));
+ embeddingList.add(new SparseEmbeddingResults.Embedding(weightedTokens, false));
+
+ // Iterates over the expected list, so extra actual embeddings would go unnoticed.
+ for (int i = 0; i < embeddingList.size(); i++) {
+ assertThat(sparseEmbeddingResults.embeddings().get(i), is(embeddingList.get(i)));
+ }
+ }
+
+ // Rerank: the parser map supplies the index and score paths; the document text path is left out.
+ public void testFromRerankResponse() throws IOException {
+ String responseJson = """
+ {
+ "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
+ "latency": 564.903929,
+ "usage": {
+ "doc_count": 2
+ },
+ "result": {
+ "scores":[
+ {
+ "index":1,
+ "score": 1.37
+ },
+ {
+ "index":0,
+ "score": -0.3
+ }
+ ]
+ }
+ }
+ """;
+
+ Map jsonParserMap = new HashMap<>(
+ Map.of(
+ CustomServiceSettings.RERANK_PARSER_INDEX,
+ "$.result.scores[*].index",
+ CustomServiceSettings.RERANK_PARSER_SCORE,
+ "$.result.scores[*].score"
+ )
+ );
+
+ var request = CustomRequestTests.createRequest(null, List.of("abc"), CustomModelTests.getTestModel(TaskType.RERANK, jsonParserMap));
+
+ InferenceServiceResults results = CustomResponseEntity.fromResponse(
+ request,
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(results, instanceOf(RankedDocsResults.class));
+ var expected = new ArrayList();
+ expected.add(new RankedDocsResults.RankedDoc(1, 1.37F, null));
+ expected.add(new RankedDocsResults.RankedDoc(0, -0.3F, null));
+
+ // NOTE(review): the loop is bounded by the actual result size and compares only index(), so an
+ // empty result or wrong scores would still pass — consider asserting the size and the scores.
+ for (int i = 0; i < ((RankedDocsResults) results).getRankedDocs().size(); i++) {
+ assertThat(((RankedDocsResults) results).getRankedDocs().get(i).index(), is(expected.get(i).index()));
+ }
+ }
+
+ // Completion: the parser map points at the result text; expects exactly one result with that text.
+ public void testFromCompletionResponse() throws IOException {
+ String responseJson = """
+ {
+ "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
+ "latency": 564.903929,
+ "result": {
+ "text":"completion results"
+ },
+ "usage": {
+ "output_tokens": 6320,
+ "input_tokens": 35,
+ "total_tokens": 6355
+ }
+ }
+ """;
+
+ Map jsonParserMap = new HashMap<>(Map.of(CustomServiceSettings.COMPLETION_PARSER_RESULT, "$.result.text"));
+
+ var request = CustomRequestTests.createRequest(
+ null,
+ List.of("abc"),
+ CustomModelTests.getTestModel(TaskType.COMPLETION, jsonParserMap)
+ );
+
+ InferenceServiceResults results = CustomResponseEntity.fromResponse(
+ request,
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(results, instanceOf(ChatCompletionResults.class));
+ ChatCompletionResults chatCompletionResults = (ChatCompletionResults) results;
+ assertThat(chatCompletionResults.getResults().size(), is(1));
+ assertThat(chatCompletionResults.getResults().get(0).content(), is("completion results"));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
new file mode 100644
index 0000000000000..0ef4ee71e7622
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import io.netty.handler.codec.http.HttpMethod;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.custom.CustomUtils;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class CustomModelTests extends ESTestCase {
+ public static String taskSettingsKey = "test_taskSettings_key";
+ public static String taskSettingsValue = "test_taskSettings_value";
+
+ public static String secretSettingsKey = "test_secret_key";
+ public static String secretSettingsValue = "test_secret_value";
+ public static String url = "http://www.abc.com";
+ public static String path = "/endpoint";
+
+ public void testOverride() {
+ var model = createModel(
+ "service",
+ TaskType.TEXT_EMBEDDING,
+ CustomServiceSettingsTests.createRandom(),
+ CustomTaskSettingsTests.createRandom(),
+ CustomSecretSettingsTests.createRandom()
+ );
+
+ var overriddenModel = CustomModel.of(model, Map.of());
+ MatcherAssert.assertThat(overriddenModel, is(model));
+ }
+
+ public static CustomModel createModel(
+ String modelId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secrets
+ ) {
+ return new CustomModel(modelId, taskType, CustomUtils.SERVICE_NAME, serviceSettings, taskSettings, secrets, null);
+ }
+
+ public static CustomModel createModel(
+ String modelId,
+ TaskType taskType,
+ CustomServiceSettings serviceSettings,
+ CustomTaskSettings taskSettings,
+ @Nullable CustomSecretSettings secretSettings
+ ) {
+ return new CustomModel(modelId, taskType, CustomUtils.SERVICE_NAME, serviceSettings, taskSettings, secretSettings);
+ }
+
+ public static CustomModel getTestModel() {
+ TaskType taskType = TaskType.TEXT_EMBEDDING;
+ Map jsonParserMap = new HashMap<>(
+ Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ );
+ return getTestModel(taskType, jsonParserMap);
+ }
+
+ public static CustomModel getTestModel(TaskType taskType, Map jsonParserMap) {
+ // service settings
+ Integer dims = 1536;
+ Integer maxInputTokens = 512;
+ String description = "test fromMap";
+ String version = "v1";
+ String serviceType = taskType.toString();
+ String method = HttpMethod.POST.name();
+ String queryString = "?query=${" + taskSettingsKey + "}";
+ Map headers = Map.of(HttpHeaders.AUTHORIZATION, "${" + secretSettingsKey + "}");
+ String requestFormat = CustomServiceSettings.REQUEST_FORMAT_STRING;
+ String requestContentString = "\"input\":\"${input}\"";
+
+ ResponseJsonParser responseJsonParser = new ResponseJsonParser(taskType, jsonParserMap, new ValidationException());
+
+ CustomServiceSettings serviceSettings = new CustomServiceSettings(
+ SimilarityMeasure.DOT_PRODUCT,
+ dims,
+ maxInputTokens,
+ description,
+ version,
+ serviceType,
+ url,
+ path,
+ method,
+ queryString,
+ headers,
+ requestFormat,
+ null,
+ requestContentString,
+ responseJsonParser,
+ new RateLimitSettings(10_000)
+ );
+
+ // task settings
+ CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue), false);
+
+ // secret settings
+ CustomSecretSettings secretSettings = new CustomSecretSettings(Map.of(secretSettingsKey, secretSettingsValue));
+
+ return CustomModelTests.createModel("service", taskType, serviceSettings, taskSettings, secretSettings);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java
new file mode 100644
index 0000000000000..357a558bd493f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettingsTests.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.core.Tuple.tuple;
+import static org.hamcrest.Matchers.is;
+
+ /** Wire-serialization, {@code fromMap}, and XContent tests for {@link CustomSecretSettings}. */
+ public class CustomSecretSettingsTests extends AbstractWireSerializingTestCase<CustomSecretSettings> {
+     /** Creates settings whose secret parameters are either {@code null} or a map built with {@code randomMap(0, 5, ...)}. */
+     public static CustomSecretSettings createRandom() {
+         var secretParameters = randomBoolean()
+             ? randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5)))
+             : null;
+         return new CustomSecretSettings(secretParameters);
+     }
+
+     public void testFromMap() {
+         Map<String, Object> secretParameters = new HashMap<>(
+             Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))
+         );
+
+         var expected = new CustomSecretSettings(Map.of("test_key", "test_value"));
+         MatcherAssert.assertThat(CustomSecretSettings.fromMap(secretParameters), is(expected));
+     }
+
+     public void testXContent() throws IOException {
+         var entity = new CustomSecretSettings(Map.of("test_key", "test_value"));
+
+         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+         entity.toXContent(builder, null);
+         String xContentResult = Strings.toString(builder);
+         // The parameters are expected to be rendered under the "secret_parameters" field.
+         assertThat(xContentResult, is("{\"secret_parameters\":{\"test_key\":\"test_value\"}}"));
+     }
+
+     @Override
+     protected Writeable.Reader<CustomSecretSettings> instanceReader() {
+         return CustomSecretSettings::new;
+     }
+
+     @Override
+     protected CustomSecretSettings createTestInstance() {
+         return createRandom();
+     }
+
+     @Override
+     protected CustomSecretSettings mutateInstance(CustomSecretSettings instance) {
+         return null; // NOTE(review): no mutation implemented; presumably leaves the inequality check vacuous - confirm
+     }
+ }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
new file mode 100644
index 0000000000000..92ad4d3dbd26e
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
@@ -0,0 +1,291 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import io.netty.handler.codec.http.HttpMethod;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.core.Tuple.tuple;
+import static org.hamcrest.Matchers.is;
+
+public class CustomServiceSettingsTests extends AbstractWireSerializingTestCase {
+ public static CustomServiceSettings createRandom(String inputUrl, String inputPath, String inputQueryString) {
+ List taskTypeStrList = Arrays.stream(TaskType.values()).map(TaskType::toString).toList();
+ TaskType taskType = TaskType.fromString(randomFrom(taskTypeStrList));
+
+ SimilarityMeasure similarityMeasure = null;
+ Integer dims = null;
+ var isTextEmbeddingModel = taskType.equals(TaskType.TEXT_EMBEDDING);
+ if (isTextEmbeddingModel) {
+ similarityMeasure = SimilarityMeasure.DOT_PRODUCT;
+ dims = 1536;
+ }
+ Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
+ String description = randomAlphaOfLength(15);
+ String version = randomAlphaOfLength(5);
+ String url = inputUrl != null ? inputUrl : randomAlphaOfLength(15);
+ String path = inputPath != null ? inputPath : randomAlphaOfLength(15);
+ String method = randomFrom(HttpMethod.PUT.name(), HttpMethod.POST.name(), HttpMethod.GET.name());
+ String queryString = inputQueryString != null ? inputQueryString : randomAlphaOfLength(15);
+ Map headers = randomBoolean() ? null : Map.of("key", "value");
+ String requestFormat = randomFrom(CustomServiceSettings.REQUEST_FORMAT_JSON, CustomServiceSettings.REQUEST_FORMAT_STRING);
+ Map requestContent = randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), randomAlphaOfLength(5)));
+ String requestContentString = randomBoolean() ? null : randomAlphaOfLength(10);
+
+ Map textEmbeddingJsonParserMap = new HashMap<>(
+ Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ );
+ Map sparseEmbeddingJsonParserMap = new HashMap<>(
+ Map.of(
+ CustomServiceSettings.SPARSE_EMBEDDING_RESULT,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.SPARSE_RESULT_PATH,
+ "$.result.sparse_embeddings[*]",
+ CustomServiceSettings.SPARSE_RESULT_VALUE,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.SPARSE_EMBEDDING_PARSER_TOKEN,
+ "$.embedding[*].token_id",
+ CustomServiceSettings.SPARSE_EMBEDDING_PARSER_WEIGHT,
+ "$.embedding[*].weights"
+ )
+ )
+ )
+ )
+ )
+ );
+ Map rerankJsonParserMap = new HashMap<>(
+ Map.of(
+ CustomServiceSettings.RERANK_PARSER_INDEX,
+ "$.result.reranked_results[*].index",
+ CustomServiceSettings.RERANK_PARSER_SCORE,
+ "$.result.reranked_results[*].relevance_score",
+ CustomServiceSettings.RERANK_PARSER_DOCUMENT_TEXT,
+ "$.result.reranked_results[*].document_text"
+ )
+ );
+ Map completionJsonParserMap = new HashMap<>(
+ Map.of(CustomServiceSettings.COMPLETION_PARSER_RESULT, "$.result.text")
+ );
+
+ ResponseJsonParser responseJsonParser = switch (taskType) {
+ case TEXT_EMBEDDING -> new ResponseJsonParser(taskType, textEmbeddingJsonParserMap, new ValidationException());
+ case SPARSE_EMBEDDING -> new ResponseJsonParser(taskType, sparseEmbeddingJsonParserMap, new ValidationException());
+ case RERANK -> new ResponseJsonParser(taskType, rerankJsonParserMap, new ValidationException());
+ case COMPLETION -> new ResponseJsonParser(taskType, completionJsonParserMap, new ValidationException());
+ default -> null;
+ };
+
+ RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000));
+
+ return new CustomServiceSettings(
+ similarityMeasure,
+ dims,
+ maxInputTokens,
+ description,
+ version,
+ taskType.name(),
+ url,
+ path,
+ method,
+ queryString,
+ headers,
+ requestFormat,
+ requestContent,
+ requestContentString,
+ responseJsonParser,
+ rateLimitSettings
+ );
+ }
+
+ public static CustomServiceSettings createRandom() {
+ return createRandom(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5));
+ }
+
+ public void testFromMap() {
+ String similarity = SimilarityMeasure.DOT_PRODUCT.toString();
+ Integer dims = 1536;
+ Integer maxInputTokens = 512;
+ String description = "test fromMap";
+ String version = "v1";
+ String serviceType = TaskType.TEXT_EMBEDDING.toString();
+ String url = "http://www.abc.com";
+ String path = "/endpoint";
+ String method = HttpMethod.POST.name();
+ String queryString = "?query=test";
+ Map headers = Map.of("key", "value");
+ String requestFormat = CustomServiceSettings.REQUEST_FORMAT_STRING;
+ String requestContentString = "request body";
+
+ Map jsonParserMap = new HashMap<>(
+ Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ );
+ ResponseJsonParser responseJsonParser = new ResponseJsonParser(TaskType.TEXT_EMBEDDING, jsonParserMap, new ValidationException());
+
+ var settings = CustomServiceSettings.fromMap(
+ new HashMap<>(
+ Map.of(
+ ServiceFields.SIMILARITY,
+ similarity,
+ ServiceFields.DIMENSIONS,
+ dims,
+ ServiceFields.MAX_INPUT_TOKENS,
+ maxInputTokens,
+ CustomServiceSettings.DESCRIPTION,
+ description,
+ CustomServiceSettings.VERSION,
+ version,
+ CustomServiceSettings.URL,
+ url,
+ CustomServiceSettings.PATH,
+ new HashMap<>(
+ Map.of(
+ path,
+ new HashMap<>(
+ Map.of(
+ method,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.QUERY_STRING,
+ queryString,
+ CustomServiceSettings.HEADERS,
+ headers,
+ CustomServiceSettings.REQUEST,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.REQUEST_FORMAT,
+ requestFormat,
+ CustomServiceSettings.REQUEST_CONTENT,
+ requestContentString
+ )
+ ),
+ CustomServiceSettings.RESPONSE,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.JSON_PARSER,
+ new HashMap<>(
+ Map.of(
+ CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS,
+ "$.result.embeddings[*].embedding"
+ )
+ )
+ )
+ )
+ )
+ )
+ )
+ )
+ )
+ )
+ )
+ ),
+ null,
+ TaskType.TEXT_EMBEDDING
+ );
+
+ MatcherAssert.assertThat(
+ settings,
+ is(
+ new CustomServiceSettings(
+ SimilarityMeasure.DOT_PRODUCT,
+ dims,
+ maxInputTokens,
+ description,
+ version,
+ serviceType,
+ url,
+ path,
+ method,
+ queryString,
+ headers,
+ requestFormat,
+ null,
+ requestContentString,
+ responseJsonParser,
+ new RateLimitSettings(10_000)
+ )
+ )
+ );
+ }
+
+ public void testXContent() throws IOException {
+ Map jsonParserMap = new HashMap<>(
+ Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
+ );
+
+ ResponseJsonParser responseJsonParser = new ResponseJsonParser(TaskType.TEXT_EMBEDDING, jsonParserMap, new ValidationException());
+
+ var entity = new CustomServiceSettings(
+ null,
+ null,
+ null,
+ "test fromMap",
+ "v1",
+ TaskType.TEXT_EMBEDDING.toString(),
+ "http://www.abc.com",
+ "/endpoint",
+ HttpMethod.POST.name(),
+ "?query=test",
+ Map.of("key", "value"),
+ CustomServiceSettings.REQUEST_FORMAT_STRING,
+ null,
+ "request body",
+ responseJsonParser,
+ null
+ );
+
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ String xContentResult = Strings.toString(builder);
+
+ assertThat(
+ xContentResult,
+ is(
+ "{\"description\":\"test fromMap\",\"version\":\"v1\","
+ + "\"url\":\"http://www.abc.com\",\"path\":{\"/endpoint\":{\"POST\":{\"query_string\":\"?query=test\","
+ + "\"headers\":{\"key\":\"value\"},\"request\":{\"format\":\"string\",\"content\":\"request body\"},"
+ + "\"response\":{\"json_parser\":{\"text_embeddings\":\"$.result.embeddings[*].embedding\"}}}}},"
+ + "\"rate_limit\":{\"requests_per_minute\":10000}}"
+ )
+ );
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return CustomServiceSettings::new;
+ }
+
+ @Override
+ protected CustomServiceSettings createTestInstance() {
+ return createRandom(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5));
+ }
+
+ @Override
+ protected CustomServiceSettings mutateInstance(CustomServiceSettings instance) {
+ return null;
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java
new file mode 100644
index 0000000000000..bd7db6410da48
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.core.Tuple.tuple;
+import static org.hamcrest.Matchers.is;
+
+ /** Wire-serialization, {@code fromMap}, and XContent tests for {@link CustomTaskSettings}. */
+ public class CustomTaskSettingsTests extends AbstractWireSerializingTestCase<CustomTaskSettings> {
+     /** Creates settings with a random placeholder-check flag and parameters that are {@code null} or a random map. */
+     public static CustomTaskSettings createRandom() {
+         var parameters = randomBoolean() ? randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5))) : null;
+         var ignorePlaceholderCheck = randomBoolean();
+         return new CustomTaskSettings(parameters, ignorePlaceholderCheck);
+     }
+
+     public void testFromMap() {
+         Map<String, Object> taskSettingsMap = new HashMap<>(
+             Map.of(
+                 CustomTaskSettings.PARAMETERS,
+                 new HashMap<>(Map.of("test_key", "test_value")),
+                 CustomTaskSettings.IGNORE_PLACEHOLDER_CHECK,
+                 true
+             )
+         );
+
+         var expected = new CustomTaskSettings(Map.of("test_key", "test_value"), true);
+         MatcherAssert.assertThat(CustomTaskSettings.fromMap(taskSettingsMap), is(expected));
+     }
+
+     public void testXContent() throws IOException {
+         var entity = new CustomTaskSettings(Map.of("test_key", "test_value"), true);
+
+         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+         entity.toXContent(builder, null);
+         String xContentResult = Strings.toString(builder);
+         // Both the parameters map and the placeholder-check flag are expected in the rendered output.
+         assertThat(xContentResult, is("{\"parameters\":{\"test_key\":\"test_value\"},\"ignore_placeholder_check\":true}"));
+     }
+
+     @Override
+     protected Writeable.Reader<CustomTaskSettings> instanceReader() {
+         return CustomTaskSettings::new;
+     }
+
+     @Override
+     protected CustomTaskSettings createTestInstance() {
+         return createRandom();
+     }
+
+     @Override
+     protected CustomTaskSettings mutateInstance(CustomTaskSettings instance) throws IOException {
+         return null; // NOTE(review): no mutation implemented; presumably leaves the inequality check vacuous - confirm
+     }
+
+     /**
+      * Builds a mutable task-settings map that contains only the entries whose argument is non-null.
+      */
+     public static Map<String, Object> getTaskSettingsMap(@Nullable Map<String, ?> parameters, @Nullable Boolean ignorePlaceholderCheck) {
+         Map<String, Object> map = new HashMap<>();
+         if (parameters != null) {
+             map.put(CustomTaskSettings.PARAMETERS, parameters);
+         }
+         if (ignorePlaceholderCheck != null) {
+             map.put(CustomTaskSettings.IGNORE_PLACEHOLDER_CHECK, ignorePlaceholderCheck);
+         }
+
+         return map;
+     }
+ }