diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 010ac5528f46e..fc9f4ccfb63f4 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -11,6 +11,7 @@ + diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 4d94ff717f205..c3c34636d70b8 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -185,6 +185,7 @@ static TransportVersion def(int id) { public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM = def(9_023_0_00); public static final TransportVersion MAX_OPERATION_SIZE_REJECTIONS_ADDED = def(9_024_0_00); public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR = def(9_025_0_00); + public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_026_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index 73a0e3cc8a774..abac5eee3ae12 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -25,6 +25,12 @@ public enum TaskType implements Writeable { SPARSE_EMBEDDING, RERANK, COMPLETION, + CUSTOM { + @Override + public boolean isAnyOrSame(TaskType other) { + return true; + } + }, ANY { @Override public boolean isAnyOrSame(TaskType other) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index dc177795af76a..cce33232724ea 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -199,6 +199,14 @@ public ActionRequestValidationException validate() { } } + if (taskType.equals(TaskType.CUSTOM)) { + if (query == null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Field [query] cannot be null for task type [%s]", TaskType.CUSTOM)); + return e; + } + } + return null; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/CustomServiceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/CustomServiceResults.java new file mode 100644 index 0000000000000..f91a4c8cb9fdd --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/CustomServiceResults.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.core.ml.inference.results.CustomResults; + +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +public class CustomServiceResults implements InferenceServiceResults { + public static final String NAME = "custom_service_results"; + public static final String CUSTOM_TYPE = TaskType.CUSTOM.name().toLowerCase(Locale.ROOT); + + Map data; + + public CustomServiceResults(Map data) { + this.data = data; + } + + public CustomServiceResults(StreamInput in) throws IOException { + this.data = in.readGenericMap(); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return ChunkedToXContentHelper.object(CUSTOM_TYPE, this.asMap()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeGenericMap(data); + } + + @Override + public List transformToCoordinationFormat() { + return transformToLegacyFormat(); + } + + @Override + public List transformToLegacyFormat() { + return List.of(new CustomResults(data)); + } + + @Override + public Map asMap() { + return data; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(NAME); + sb.append(Integer.toHexString(hashCode())); + sb.append("\n"); + sb.append(this.asMap().toString()); + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) 
return true; + if (o == null || getClass() != o.getClass()) return false; + CustomServiceResults that = (CustomServiceResults) o; + return data.equals(that.data); + } + + @Override + public int hashCode() { + return Objects.hash(data); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/CustomResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/CustomResults.java new file mode 100644 index 0000000000000..f91edcdebec9a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/CustomResults.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +public class CustomResults implements InferenceResults { + public static final String NAME = "custom_results"; + public static final String CUSTOM_TYPE = TaskType.CUSTOM.toString(); + + Map data; + + public CustomResults(Map data) { + this.data = data; + } + + public CustomResults(StreamInput in) throws IOException { + this.data = in.readGenericMap(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(CUSTOM_TYPE, this.asMap()); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void 
writeTo(StreamOutput out) throws IOException { + out.writeGenericMap(data); + } + + @Override + public String getResultsField() { + return CUSTOM_TYPE; + } + + @Override + public Map asMap() { + return data; + } + + @Override + public Map asMap(String outputField) { + Map map = new LinkedHashMap<>(); + map.put(outputField, this.asMap()); + return map; + } + + @Override + public Map predictedValue() { + return this.asMap(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomResults that = (CustomResults) o; + return data.equals(that.data); + } + + @Override + public int hashCode() { + return Objects.hash(data); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index e9f4df7a523ad..0ed6007c3a783 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -102,6 +102,21 @@ public void testValidation_Rerank() { assertNull(e); } + public void testValidation_Custom() { + InferenceAction.Request request = new InferenceAction.Request( + TaskType.CUSTOM, + "model", + "query", + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException e = request.validate(); + assertNull(e); + } + public void testValidation_TextEmbedding_Null() { InferenceAction.Request inputNullRequest = new InferenceAction.Request( TaskType.TEXT_EMBEDDING, @@ -166,6 +181,37 @@ public void testValidation_Rerank_Empty() { assertThat(queryEmptyError.getMessage(), is("Validation Failed: 1: Field [query] cannot be empty for task type [rerank];")); } + public void testValidation_Custom_Null() { + 
InferenceAction.Request queryNullRequest = new InferenceAction.Request( + TaskType.CUSTOM, + "model", + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryNullError = queryNullRequest.validate(); + assertNotNull(queryNullError); + assertThat(queryNullError.getMessage(), is("Validation Failed: 1: Field [query] cannot be null for task type [custom];")); + } + + public void testValidation_Custom_Empty() { + InferenceAction.Request queryNullRequest = new InferenceAction.Request( + TaskType.CUSTOM, + "model", + "", + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException e = queryNullRequest.validate(); + assertNull(e); + } + public void testParseRequest_DefaultsInputTypeToIngest() throws IOException { String singleInputRequest = """ { diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index cb3ccbd171304..9ce4609f254a3 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -38,6 +38,7 @@ dependencies { clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') api "com.ibm.icu:icu4j:${versions.icu4j}" + api "org.apache.commons:commons-lang3:${versions.commons_lang3}" runtimeOnly 'com.google.guava:guava:32.0.1-jre' implementation 'com.google.code.gson:gson:2.10' @@ -57,6 +58,11 @@ dependencies { implementation 'io.grpc:grpc-context:1.49.2' implementation 'io.opencensus:opencensus-api:0.31.1' implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1' + implementation 'org.apache.commons:commons-text:1.4' + implementation 'com.jayway.jsonpath:json-path:2.9.0' + implementation 'net.minidev:json-smart:2.5.2' + implementation 'net.minidev:accessors-smart:2.5.2' + /* AWS SDK v2 */ implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}") diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 
78f30e7da0670..11dd5997765da 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -35,6 +35,10 @@ requires org.reactivestreams; requires org.elasticsearch.logging; requires org.elasticsearch.sslconfig; + requires org.apache.commons.text; + requires json.path; + requires unboundid.ldapsdk; + requires json.smart; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index d57a6b86e4e71..61272dcd95635 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -58,6 +58,8 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; @@ -149,6 +151,7 @@ public static List getNamedWriteables() { addAlibabaCloudSearchNamedWriteables(namedWriteables); addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); 
+ addCustomWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -663,4 +666,11 @@ private static void addEisNamedWriteables(List nam ) ); } + + private static void addCustomWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new) + ); + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new)); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 714829c08b041..5d8eeb5f2bad7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -115,6 +115,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.custom.CustomService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; @@ -358,6 +359,7 @@ public List getInferenceServiceFactories() { context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), + context -> new CustomService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomAction.java new file mode 100644 index 0000000000000..4754c321f27b4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomAction.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.custom; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.CustomRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.custom.CustomModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class CustomAction implements ExecutableAction { + private final CustomModel model; + private final String failedToSendRequestErrorMessage; + private final Sender sender; + private final 
CustomRequestManager requestManager; + + public CustomAction(Sender sender, CustomModel model, ServiceComponents serviceComponents) { + this.model = Objects.requireNonNull(model); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Custom Search"); + this.sender = Objects.requireNonNull(sender); + this.requestManager = CustomRequestManager.of(model, serviceComponents.threadPool()); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionCreator.java new file mode 100644 index 0000000000000..55c5cbfd0f731 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionCreator.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.custom; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.custom.CustomModel; + +import java.util.Map; +import java.util.Objects; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the custom model type. + */ +public class CustomActionCreator implements CustomActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public CustomActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(CustomModel model, Map taskSettings, InputType inputType) { + var overriddenModel = CustomModel.of(model, taskSettings); + + return new CustomAction(sender, overriddenModel, serviceComponents); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionVisitor.java new file mode 100644 index 0000000000000..2c66e008671b3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/custom/CustomActionVisitor.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.custom; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.custom.CustomModel; + +import java.util.Map; + +public interface CustomActionVisitor { + ExecutableAction create(CustomModel model, Map taskSettings, InputType inputType); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/custom/CustomResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/custom/CustomResponseHandler.java new file mode 100644 index 0000000000000..fb22ac0832ff8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/custom/CustomResponseHandler.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.custom; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.custom.CustomErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; + +/** + * Defines how to handle various errors returned from the custom integration. 
+ */ +public class CustomResponseHandler extends BaseResponseHandler { + public CustomResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, CustomErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). + * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode >= 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CustomRequestManager.java new file mode 100644 index 0000000000000..da76425c93e98 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CustomRequestManager.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.custom.CustomResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.custom.CustomRequest; +import org.elasticsearch.xpack.inference.external.response.custom.CustomResponseEntity; +import org.elasticsearch.xpack.inference.services.custom.CustomModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class CustomRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(CustomRequestManager.class); + + private static final ResponseHandler HANDLER = createCustomHandler(); + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(CustomModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().hashCode()); + } + } + + private static ResponseHandler createCustomHandler() { + return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse); + } + + public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) { + return new 
CustomRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final CustomModel model; + + private CustomRequestManager(CustomModel model, ThreadPool threadPool) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + String query; + List input; + if (inferenceInputs instanceof QueryAndDocsInputs) { + QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs); + query = queryAndDocsInputs.getQuery(); + input = queryAndDocsInputs.getChunks(); + } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) { + query = null; + input = chatInputs.getInputs(); + } else if (inferenceInputs instanceof DocumentsOnlyInput) { + DocumentsOnlyInput docsInputs = DocumentsOnlyInput.of(inferenceInputs); + query = null; + input = docsInputs.getInputs(); + } else { + throw InferenceInputs.createUnsupportedTypeException(inferenceInputs, InferenceInputs.class); + } + CustomRequest request = new CustomRequest(query, input, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequest.java new file mode 100644 index 0000000000000..838dc561c6399 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequest.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.custom; + +import com.google.gson.Gson; + +import org.apache.commons.text.StringSubstitutor; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.StringEntity; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.custom.CustomModel; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_FORMAT_JSON; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.REQUEST_FORMAT_STRING; + +public class CustomRequest implements Request { + private final long startTime; + + public static final Gson 
gson; + static { + gson = new Gson(); + } + + private final String QUERY = "query"; + private final String INPUT = "input"; + + private final Map customParams; + + private final CustomServiceSettings serviceSettings; + private final CustomTaskSettings taskSettings; + private final Map secretParameters; + private final String url; + private final String path; + private final String method; + private final String queryString; + private final Map headers; + private final String requestFormat; + private final Map requestContent; + private final String requestContentString; + private final URI uri; + StringSubstitutor substitutor; + private final String inferenceEntityId; + + public CustomRequest(String query, List input, CustomModel model) { + this.startTime = System.currentTimeMillis(); + Objects.requireNonNull(model); + + serviceSettings = model.getServiceSettings(); + taskSettings = model.getTaskSettings(); + secretParameters = model.getSecretSettings().getSecretParameters(); + path = serviceSettings.getPath(); + method = serviceSettings.getMethod().toUpperCase(Locale.ROOT); + queryString = serviceSettings.getQueryString(); + headers = serviceSettings.getHeaders(); + requestFormat = serviceSettings.getRequestFormat(); + requestContent = serviceSettings.getRequestContent(); + requestContentString = serviceSettings.getRequestContentString(); + url = model.getServiceSettings().getUrl(); + + Map customParamsObjectMap = new HashMap<>(); + if (secretParameters != null) { + for (String key : secretParameters.keySet()) { + Object paramValue = secretParameters.get(key); + if (paramValue instanceof SecureString) { + customParamsObjectMap.put(key, ((SecureString) paramValue).toString()); + } else { + customParamsObjectMap.put(key, paramValue); + } + } + } + + customParams = new HashMap(); + if (taskSettings.getParameters() != null && taskSettings.getParameters().isEmpty() == false) { + Map taskParams = getParameterMap(taskSettings.getParameters()); + for (String key : 
taskParams.keySet()) { + customParams.put(key, taskParams.get(key)); + } + } + + // if user's custom parameters contain input and query, it will be replaced by inference's input and query + if (query != null) { + customParamsObjectMap.put(QUERY, query); + } + + String serviceType = serviceSettings.getServiceType(); + TaskType taskType = TaskType.fromStringOrStatusException(serviceType); + if (taskType.equals(TaskType.COMPLETION)) { + if (input.size() == 1) { + customParamsObjectMap.put(INPUT, input.get(0)); + } else { + customParamsObjectMap.put(INPUT, input); + } + } else { + customParamsObjectMap.put(INPUT, input); + } + customParams.putAll(getParameterMap(customParamsObjectMap)); + + substitutor = new StringSubstitutor(customParams, "${", "}"); + + uri = buildUri(); + inferenceEntityId = model.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpRequestBase httpRequest; + if (method.equalsIgnoreCase(HttpGet.METHOD_NAME)) { + httpRequest = new HttpGet(uri); + } else if (method.equalsIgnoreCase(HttpPost.METHOD_NAME)) { + httpRequest = new HttpPost(uri); + } else if (method.equalsIgnoreCase(HttpPut.METHOD_NAME)) { + httpRequest = new HttpPut(uri); + } else { + throw new IllegalArgumentException("unsupported http method [" + method + "], support GET, PUT and POST"); + } + + setHeaders(httpRequest); + setRequestContent(httpRequest); + + return new HttpRequest(httpRequest, getInferenceEntityId()); + } + + private void setHeaders(HttpRequestBase httpRequest) { + // Header content_type's default value, if user defines the Content-Type, it will be replaced by user's value; + httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + if (headers != null && headers.isEmpty() == false) { + for (String key : headers.keySet()) { + String headersValue = (String) headers.get(key); + String replacedHeadersValue = substitutor.replace(headersValue); + placeholderValidation(replacedHeadersValue); + 
httpRequest.setHeader(key, replacedHeadersValue); + } + } + } + + private void setRequestContent(HttpRequestBase httpRequest) { + switch (requestFormat.toLowerCase()) { + case REQUEST_FORMAT_JSON: { + // todo: support json format request + break; + } + case REQUEST_FORMAT_STRING: { + if (requestContentString != null && (method.equals(HttpPost.METHOD_NAME) || method.equals(HttpPut.METHOD_NAME))) { + String replacedRequestContentString = substitutor.replace(requestContentString); + placeholderValidation(replacedRequestContentString); + StringEntity stringEntity = new StringEntity(replacedRequestContentString, StandardCharsets.UTF_8); + if (httpRequest instanceof HttpPost) { + ((HttpPost) httpRequest).setEntity(stringEntity); + } else if (httpRequest instanceof HttpPut) { + ((HttpPut) httpRequest).setEntity(stringEntity); + } + } + break; + } + } + + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return uri; + } + + public CustomServiceSettings getServiceSettings() { + return serviceSettings; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + URI buildUri() { + try { + String uri = url + path; + if (queryString != null) { + String replacedQueryString = substitutor.replace(queryString); + placeholderValidation(replacedQueryString); + uri = uri + replacedQueryString; + } + return new URI(uri); + } catch (URISyntaxException e) { + // using bad request here so that potentially sensitive URL information does not get logged + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", CustomUtils.SERVICE_NAME), + RestStatus.BAD_REQUEST, + e + ); + } + } + + @SuppressWarnings("removal") + public static Map getParameterMap(Map parameterObjs) { + Map parameters = new HashMap<>(); + for (String key : parameterObjs.keySet()) { + Object value = parameterObjs.get(key); + try { + 
AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + if (value instanceof String) { + parameters.put(key, (String) value); + } else { + parameters.put(key, gson.toJson(value)); + } + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + return parameters; + } + + private void placeholderValidation(String s) throws IllegalArgumentException { + if (taskSettings.getIgnorePlaceholderCheck() != null && taskSettings.getIgnorePlaceholderCheck()) { + return; + } + String pattern = "\\$\\{.*?\\}"; + Pattern compiledPattern = Pattern.compile(pattern); + Matcher matcher = compiledPattern.matcher(s); + if (matcher.find()) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "placeholder is not replaced, find placeholder in [%s]", s)); + } + } + + public long getStartTime() { + return startTime; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestEntity.java new file mode 100644 index 0000000000000..6146cd1b7310e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomRequestEntity.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.custom; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings; +import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public record CustomRequestEntity(CustomServiceSettings serviceSettings, CustomTaskSettings taskSettings) implements ToXContentObject { + + public CustomRequestEntity { + Objects.requireNonNull(serviceSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + Map requestContent = serviceSettings.getRequestContent(); + Map taskSettingsContent = taskSettings.getParameters(); + builder.startObject(); + { + for (String key : requestContent.keySet()) { + builder.field(key, requestContent.get(key)); + } + if (taskSettingsContent != null && taskSettingsContent.isEmpty() == false) { + for (String key : taskSettingsContent.keySet()) { + builder.field(key, taskSettingsContent.get(key)); + } + } + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomUtils.java new file mode 100644 index 0000000000000..33cfcd239b1d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/custom/CustomUtils.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.custom; + +public class CustomUtils { + public static final String SERVICE_NAME = "custom-model"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomErrorResponseEntity.java new file mode 100644 index 0000000000000..fca794f3a29f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomErrorResponseEntity.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.custom; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.util.Locale; + +public class CustomErrorResponseEntity extends ErrorResponse { + private static final Logger logger = LogManager.getLogger(CustomErrorResponseEntity.class); + + private CustomErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + + public static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + if 
(logger.isDebugEnabled()) { + logger.debug("Received a server error response: {}, body:{}", response.response(), responseMap.toString()); + } + + return new ErrorResponse( + String.format(Locale.ROOT, "Received a server error response: %s, body: %s", response.response(), responseMap.toString()) + ); + } catch (Exception e) { + logger.info("Parsing custom response body failed. Response: {}", response.response(), e); + return new ErrorResponse(String.format(Locale.ROOT, "Parsing custom response body failed. Response: %s", response.response())); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntity.java new file mode 100644 index 0000000000000..7c55cf9a09c27 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntity.java @@ -0,0 +1,268 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.custom; + +import net.minidev.json.JSONArray; + +import com.jayway.jsonpath.JsonPath; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.CustomServiceResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.custom.CustomRequest; +import org.elasticsearch.xpack.inference.services.custom.ResponseJsonParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +public class CustomResponseEntity { + private static final Logger logger = LogManager.getLogger(CustomResponseEntity.class); + + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + CustomRequest customRequest = (CustomRequest) request; + String serviceType = 
customRequest.getServiceSettings().getServiceType(); + TaskType taskType = TaskType.fromStringOrStatusException(serviceType); + ResponseJsonParser responseJsonParser = customRequest.getServiceSettings().getResponseJsonParser(); + + InferenceServiceResults result = switch (taskType) { + case TEXT_EMBEDDING -> fromTextEmbeddingResponse(response, responseJsonParser); + case SPARSE_EMBEDDING -> fromSparseEmbeddingResponse(response, responseJsonParser); + case RERANK -> fromRerankResponse(response, responseJsonParser); + case COMPLETION -> fromCompletionResponse(response, responseJsonParser); + case CUSTOM -> fromCustomResponse(response); + default -> throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType), RestStatus.BAD_REQUEST); + }; + + logger.debug( + "Ai Search uri [{}] response: client cost [{}ms]", + request.getURI() != null ? request.getURI() : "", + System.currentTimeMillis() - customRequest.getStartTime() + ); + + return result; + } + + private static CustomServiceResults fromCustomResponse(HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + return new CustomServiceResults(jsonParser.map()); + } + } + + private static InferenceServiceResults fromTextEmbeddingResponse(HttpResult response, ResponseJsonParser responseJsonParser) { + String embeddingPath = responseJsonParser.getTextEmbeddingsPath(); + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + JSONArray embeddingResults = JsonPath.read(jsonParser.map(), embeddingPath); + if (embeddingResults == null || embeddingResults.isEmpty()) { + throw new 
IllegalArgumentException( + String.format( + Locale.ROOT, + "can't parse text_embeddings results from response, please check the path [%s]", + embeddingPath + ) + ); + } + + List embeddings = embeddingResults.stream() + .map( + embeddingResult -> ((List) embeddingResult).stream() + .map(obj -> ((Number) obj).floatValue()) + .collect(Collectors.toList()) + ) + .map(TextEmbeddingFloatResults.Embedding::of) + .collect(Collectors.toList()); + + return new TextEmbeddingFloatResults(embeddings); + } catch (Exception e) { + logger.error("failed to parse text_embeddings results from response:", e); + throw new IllegalArgumentException("failed to parse text_embeddings results from response:", e); + } + } + + private static InferenceServiceResults fromSparseEmbeddingResponse(HttpResult response, ResponseJsonParser responseJsonParser) { + String resultPath = responseJsonParser.getSparseResultPath(); + String tokenPath = responseJsonParser.getSparseTokenPath(); + String weightPath = responseJsonParser.getSparseWeightPath(); + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + Object jsonObject = jsonParser.map(); + JSONArray sparseResults = JsonPath.read(jsonObject, resultPath); + if (sparseResults == null || sparseResults.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "can't parse sparse_result results from response, please check the path [%s]", resultPath) + ); + } + + List embeddingList = new ArrayList<>(); + + for (Object obj : sparseResults) { + JSONArray tokenResults = JsonPath.read(obj, tokenPath); + if (tokenResults == null || tokenResults.isEmpty()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "can't parse sparse embeddings results from response, please check the path [%s]", + tokenPath + ) + ); + } + JSONArray 
weightResults = JsonPath.read(obj, weightPath); + if (weightResults == null || weightResults.isEmpty()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "can't parse sparse embeddings results from response, please check the path [%s]", + weightPath + ) + ); + } + List tokens = tokenResults.stream().map(Object::toString).toList(); + List weights = weightResults.stream().map(weight -> ((Number) weight).floatValue()).toList(); + + List weightedTokens = new ArrayList<>(); + if (tokens.size() != weights.size()) { + throw new IllegalArgumentException("Tokens and weights size does not match"); + } + for (int i = 0; i < tokens.size(); i++) { + weightedTokens.add(new WeightedToken(tokens.get(i), weights.get(i))); + } + + embeddingList.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + } + + return new SparseEmbeddingResults(embeddingList); + } catch (Exception e) { + logger.error("failed to parse sparse_result results from response:", e); + throw new IllegalArgumentException("failed to parse sparse_result results from response:", e); + } + } + + private static InferenceServiceResults fromRerankResponse(HttpResult response, ResponseJsonParser responseJsonParser) { + String indexPath = responseJsonParser.getRerankedIndexPath(); + String scorePath = responseJsonParser.getRelevanceScorePath(); + String docPath = responseJsonParser.getDocumentTextPath(); + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + Object jsonObject = jsonParser.map(); + JSONArray indexResults = indexPath != null ? 
JsonPath.read(jsonObject, indexPath) : null; + if (indexPath != null && (indexResults == null || indexResults.isEmpty())) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "can't parse reranked_index results from response, please check the path [%s]", indexPath) + ); + } + JSONArray scoreResults = JsonPath.read(jsonObject, scorePath); + if (scoreResults == null || scoreResults.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "can't parse relevance_score results from response, please check the path [%s]", scorePath) + ); + } + JSONArray docResults = docPath != null ? JsonPath.read(jsonObject, docPath) : null; + if (docPath != null && (docResults == null || docResults.isEmpty())) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "can't parse doc results from response, please check the path [%s]", docPath) + ); + } + + List indices = indexResults != null ? indexResults.stream().map(index -> (Integer) index).toList() : null; + + List scores = scoreResults.stream().map(score -> ((Number) score).floatValue()).toList(); + + List docs = docResults != null ? 
docResults.stream().map(doc -> (String) doc).toList() : null; + + List rerankResults = new ArrayList<>(); + if (indices != null && indices.size() != scores.size()) { + throw new IllegalArgumentException("Indices and scores size does not match"); + } + for (int i = 0; i < scores.size(); i++) { + if (docs != null) { + if (indices != null) { + rerankResults.add(new RankedDocsResults.RankedDoc(indices.get(i), scores.get(i), docs.get(i))); + } else { + rerankResults.add(new RankedDocsResults.RankedDoc(i, scores.get(i), docs.get(i))); + } + } else { + if (indices != null) { + rerankResults.add(new RankedDocsResults.RankedDoc(indices.get(i), scores.get(i), null)); + } else { + rerankResults.add(new RankedDocsResults.RankedDoc(i, scores.get(i), null)); + } + } + } + + return new RankedDocsResults(rerankResults); + } catch (Exception e) { + logger.error("failed to parse rerank results from response:", e); + throw new IllegalArgumentException("failed to parse rerank results from response:", e); + } + } + + private static InferenceServiceResults fromCompletionResponse(HttpResult response, ResponseJsonParser responseJsonParser) { + String resultPath = responseJsonParser.getCompletionResultPath(); + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + Object results = JsonPath.read(jsonParser.map(), resultPath); + if (results == null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "can't parse completion results from response, please check the path [%s]", resultPath) + ); + } + if (results instanceof String) { + return new ChatCompletionResults(List.of(new ChatCompletionResults.Result((String) results))); + } else if (results instanceof JSONArray jsonArrayResults) { + if (jsonArrayResults.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, 
"can't parse completion results from response, please check the path [%s]", resultPath) + ); + } + List completionResults = jsonArrayResults.stream() + .map(obj -> (String) obj) + .map(ChatCompletionResults.Result::new) + .collect(Collectors.toList()); + return new ChatCompletionResults(completionResults); + } else { + throw new IllegalArgumentException("Unsupported completion result type: " + results.getClass().getName()); + } + } catch (Exception e) { + logger.error("failed to parse completion results from response:", e); + throw new IllegalArgumentException("failed to parse completion results from response:", e); + } + } + + public static String unsupportedTaskTypeErrorMsg(TaskType taskType) { + return "ai search custom service does not support task type [" + taskType + "]"; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 56bf6c1359a56..589bd3aceed6c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -74,7 +74,7 @@ public void infer( private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { return switch (model.getTaskType()) { case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); - case RERANK -> new QueryAndDocsInputs(query, input, stream); + case RERANK, CUSTOM -> new QueryAndDocsInputs(query, input, stream); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream); default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 7330d45b6f16c..2abc58d435cfb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -254,6 +254,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) { return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName); } + public static String mustBeNonEmptyMap(String settingName, String scope) { + return Strings.format("[%s] Invalid value empty map. [%s] must be a non-empty map", scope, settingName); + } + public static String invalidTimeValueMsg(String timeValueStr, String settingName, String scope, String exceptionMsg) { return Strings.format( "[%s] Invalid time value [%s]. 
[%s] must be a valid time value string: %s", @@ -427,6 +431,58 @@ public static Integer extractRequiredPositiveInteger( return field; } + @SuppressWarnings("unchecked") + public static Map extractRequiredMap( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + int initialValidationErrorCount = validationException.validationErrors().size(); + Map requiredField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + if (requiredField == null) { + validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(settingName, scope)); + } else if (requiredField.isEmpty()) { + validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope)); + } + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + return requiredField; + } + + @SuppressWarnings("unchecked") + public static Map extractOptionalMap( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + int initialValidationErrorCount = validationException.validationErrors().size(); + Map optionalField = ServiceUtils.removeAsType(map, settingName, Map.class, validationException); + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + if (optionalField != null && optionalField.isEmpty()) { + validationException.addValidationError(ServiceUtils.mustBeNonEmptyMap(settingName, scope)); + } + + if (validationException.validationErrors().size() > initialValidationErrorCount) { + return null; + } + + return optionalField; + } + public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java new file mode 100644 index 0000000000000..0d68b6ebbe077 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.custom.CustomActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.util.Map; +import java.util.Objects; + +public class CustomModel extends Model { + private final CustomRateLimitServiceSettings rateLimitServiceSettings; + + public CustomModel(ModelConfigurations configurations, ModelSecrets secrets, CustomRateLimitServiceSettings rateLimitServiceSettings) { + super(configurations, secrets); + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + } + + public static CustomModel of(CustomModel model, Map taskSettings) { + var requestTaskSettings = CustomTaskSettings.fromMap(taskSettings); + return new CustomModel(model, CustomTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public CustomModel( + String modelId, + TaskType taskType, + 
String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + modelId, + taskType, + service, + CustomServiceSettings.fromMap(serviceSettings, context, taskType), + CustomTaskSettings.fromMap(taskSettings), + CustomSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + CustomModel( + String modelId, + TaskType taskType, + String service, + CustomServiceSettings serviceSettings, + CustomTaskSettings taskSettings, + @Nullable CustomSecretSettings secretSettings + ) { + this( + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + + protected CustomModel(CustomModel model, TaskSettings taskSettings) { + super(model, taskSettings); + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + protected CustomModel(CustomModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + @Override + public CustomServiceSettings getServiceSettings() { + return (CustomServiceSettings) super.getServiceSettings(); + } + + @Override + public CustomTaskSettings getTaskSettings() { + return (CustomTaskSettings) super.getTaskSettings(); + } + + @Override + public CustomSecretSettings getSecretSettings() { + return (CustomSecretSettings) super.getSecretSettings(); + } + + public ExecutableAction accept(CustomActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, inputType); + } + + public CustomRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java new file mode 100644 index 0000000000000..55641bad7ccaa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRateLimitServiceSettings.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface CustomRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java new file mode 100644 index 0000000000000..2c118869a854a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; + +public class CustomSecretSettings implements SecretSettings { + public static final String NAME = "custom_secret_settings"; + public static final String SECRET_PARAMETERS = "secret_parameters"; + private final Map secretParameters; + + public static CustomSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + + Map requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException); + if (requestSecretParamsMap == null) { + return null; + } else { + Map secureSecretParameters = new HashMap<>(); + for (String paramKey : requestSecretParamsMap.keySet()) { + Object paramValue = requestSecretParamsMap.get(paramKey); + secureSecretParameters.put(paramKey, paramValue); + } + return new CustomSecretSettings(secureSecretParameters); + } + } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + return fromMap(new HashMap<>(newSecrets)); + } + + public CustomSecretSettings(@Nullable Map secretParameters) { + this.secretParameters = secretParameters; + } + + public CustomSecretSettings(StreamInput in) throws IOException { + if (in.readBoolean()) { + secretParameters = in.readGenericMap(); + } else { + secretParameters = null; + } + } + + public Map 
getSecretParameters() { + return secretParameters; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (secretParameters != null) { + builder.field(SECRET_PARAMETERS, secretParameters); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_15_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (secretParameters != null) { + out.writeBoolean(true); + out.writeGenericMap(secretParameters); + } else { + out.writeBoolean(false); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomSecretSettings that = (CustomSecretSettings) o; + return Objects.equals(secretParameters, that.secretParameters); + } + + @Override + public int hashCode() { + return Objects.hash(secretParameters); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java new file mode 100644 index 0000000000000..a36f47af8e17b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.inference.external.action.custom.CustomActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.custom.CustomUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; + +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; + +public class CustomService extends SenderService { + private static final Logger logger = LogManager.getLogger(CustomService.class); + public static final String NAME = CustomUtils.SERVICE_NAME; + + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + TaskType.RERANK, + TaskType.COMPLETION, + TaskType.CUSTOM + ); + + public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + CustomModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + logModelConfig(model.getConfigurations()); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + 
throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + private static CustomModel createModelWithoutLoggingDeprecations( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static CustomModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + } + + @Override + public CustomModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public CustomModel 
parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelWithoutLoggingDeprecations( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof CustomModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + CustomModel customModel = (CustomModel) model; + + var actionCreator = new CustomActionCreator(getSender(), getServiceComponents()); + + var action = customModel.accept(actionCreator, taskSettings, inputType); + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + DocumentsOnlyInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + } + + /** + * For text embedding models get the embedding size and + * update the service settings. 
+ * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + listener.onResponse(model); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(NAME) + .setTaskTypes(supportedTaskTypes) + .setConfigurations(configurationMap) + .build(); + } + ); + } + + private void logModelConfig(ModelConfigurations modelConfigurations) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + XContentBuilder modelBuilder = modelConfigurations.toXContent(builder, EMPTY_PARAMS); + String jsonString = BytesReference.bytes(modelBuilder).utf8ToString(); + logger.info("add custom model: " + jsonString); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java new file mode 100644 index 0000000000000..3399307113bb1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java @@ -0,0 +1,635 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.custom.CustomUtils.SERVICE_NAME; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; +import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings { + public static final String NAME = "custom_service_settings"; + public static final String DESCRIPTION = "description"; + public static final String VERSION = "version"; + public static final String URL = "url"; + public static final String PATH = "path"; + public static final String QUERY_STRING = "query_string"; + public static final String HEADERS = "headers"; + public static final String REQUEST = "request"; + public static final String REQUEST_FORMAT = "format"; + public static final String REQUEST_CONTENT = "content"; + public static final String RESPONSE = "response"; + public static final String JSON_PARSER = "json_parser"; + + public static final String TEXT_EMBEDDING_PARSER_EMBEDDINGS = "text_embeddings"; + + public static final String SPARSE_EMBEDDING_RESULT = "sparse_result"; + public static final String SPARSE_RESULT_PATH = "path"; + public static final String SPARSE_RESULT_VALUE = "value"; + public static final String SPARSE_EMBEDDING_PARSER_TOKEN = "sparse_token"; + public static final String SPARSE_EMBEDDING_PARSER_WEIGHT = "sparse_weight"; + + public static final String RERANK_PARSER_INDEX = "reranked_index"; + public static final String RERANK_PARSER_SCORE = "relevance_score"; + public static final String RERANK_PARSER_DOCUMENT_TEXT = "document_text"; + + public static final String COMPLETION_PARSER_RESULT = "completion_result"; + + public static final String REQUEST_FORMAT_JSON = "json"; + public static final String REQUEST_FORMAT_STRING = "string"; + + private static final 
RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); + + public static CustomServiceSettings fromMap(Map map, ConfigurationParseContext context, TaskType taskType) { + ValidationException validationException = new ValidationException(); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + String description = extractOptionalString(map, DESCRIPTION, ModelConfigurations.SERVICE_SETTINGS, validationException); + String version = extractOptionalString(map, VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); + String serviceType = taskType.toString(); + String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Map pathMap = extractRequiredMap(map, PATH, ModelConfigurations.SERVICE_SETTINGS, validationException); + if (pathMap == null) { + throw validationException; + } + if (pathMap.size() > 1) { + validationException.addValidationError("[" + PATH + "] only support one endpoint, but found [" + pathMap.keySet() + "]"); + throw validationException; + } + + String path = pathMap.keySet().iterator().next(); + Map pathContent = extractRequiredMap(pathMap, path, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (pathContent == null) { + throw validationException; + } + if (pathContent.size() > 1) { + validationException.addValidationError("[" + PATH + "] only support one method, but found [" + pathContent.keySet() + "]"); + throw validationException; + } + + String method = pathContent.keySet().iterator().next(); + switch (method.toUpperCase()) { + case HttpGet.METHOD_NAME: + case HttpPut.METHOD_NAME: + case HttpPost.METHOD_NAME: + break; + default: + validationException.addValidationError( + String.format( + Locale.ROOT, + "unsupported http method [" + method 
+ "], support [%s], [%s] and [%s]", + HttpGet.METHOD_NAME, + HttpPut.METHOD_NAME, + HttpPost.METHOD_NAME + ) + ); + } + + Map modelParamsMap = extractRequiredMap( + pathContent, + method, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + if (modelParamsMap == null) { + throw validationException; + } + + String queryString = extractOptionalString(modelParamsMap, QUERY_STRING, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Map headers = extractOptionalMap( + modelParamsMap, + HEADERS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + Map requestBodyMap = extractRequiredMap( + modelParamsMap, + REQUEST, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if (requestBodyMap == null) { + throw validationException; + } + + String requestFormat = extractRequiredString( + requestBodyMap, + REQUEST_FORMAT, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + if (requestFormat == null) { + throw validationException; + } + + Map requestContent = null; + String requestContentString = null; + + switch (requestFormat) { + // todo support json format + case REQUEST_FORMAT_STRING: + requestContentString = extractRequiredString( + requestBodyMap, + REQUEST_CONTENT, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + break; + default: + validationException.addValidationError( + Strings.format( + "[%s.%s] does not support value [%s]. 
It should be [%s]", + REQUEST, + REQUEST_FORMAT, + requestFormat, + REQUEST_FORMAT_STRING + ) + ); + } + + ResponseJsonParser responseJsonParser = null; + Map responseParserMap = null; + switch (taskType) { + case TEXT_EMBEDDING: + case SPARSE_EMBEDDING: + case RERANK: + case COMPLETION: + responseParserMap = extractRequiredMap(modelParamsMap, RESPONSE, ModelConfigurations.SERVICE_SETTINGS, validationException); + if (responseParserMap == null) { + throw validationException; + } + Map jsonParserMap = extractRequiredMap( + responseParserMap, + JSON_PARSER, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + if (jsonParserMap == null) { + throw validationException; + } + responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException); + throwIfNotEmptyMap(responseParserMap, NAME); + } + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + CustomService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomServiceSettings( + similarity, + dims, + maxInputTokens, + description, + version, + serviceType, + url, + path, + method, + queryString, + headers, + requestFormat, + requestContent, + requestContentString, + responseJsonParser, + rateLimitSettings + ); + } + + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + private final String description; + private final String version; + private final String url; + private final String serviceType; + private final String path; + private final String method; + private final String queryString; + private final Map headers; + private final String requestFormat; + private final Map requestContent; + private final String requestContentString; + private final ResponseJsonParser responseJsonParser; + private final RateLimitSettings rateLimitSettings; + + public 
CustomServiceSettings( + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + String description, + String version, + String serviceType, + String url, + String path, + String method, + String queryString, + Map headers, + String requestFormat, + Map requestContent, + String requestContentString, + ResponseJsonParser responseJsonParser, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.description = description; + this.version = version; + this.serviceType = serviceType; + this.url = url; + this.path = path; + this.method = method; + this.queryString = queryString; + this.headers = headers; + this.requestFormat = requestFormat; + this.requestContent = requestContent; + this.requestContentString = requestContentString; + this.responseJsonParser = responseJsonParser; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public CustomServiceSettings(StreamInput in) throws IOException { + similarity = in.readOptionalEnum(SimilarityMeasure.class); + dimensions = in.readOptionalVInt(); + maxInputTokens = in.readOptionalVInt(); + description = in.readOptionalString(); + version = in.readOptionalString(); + serviceType = in.readString(); + url = in.readString(); + path = in.readString(); + method = in.readString(); + queryString = in.readOptionalString(); + if (in.readBoolean()) { + headers = in.readGenericMap(); + } else { + headers = null; + } + requestFormat = in.readString(); + if (in.readBoolean()) { + requestContent = in.readGenericMap(); + } else { + requestContent = null; + } + requestContentString = in.readOptionalString(); + if (in.readBoolean()) { + responseJsonParser = new ResponseJsonParser(in); + } else { + responseJsonParser = null; + } + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public SimilarityMeasure 
similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + public URI getUri() { + return null; + } + + public SimilarityMeasure getSimilarity() { + return similarity; + } + + public Integer getDimensions() { + return dimensions; + } + + public Integer getMaxInputTokens() { + return maxInputTokens; + } + + public String getUrl() { + return url; + } + + public String getServiceType() { + return serviceType; + } + + public String getPath() { + return path; + } + + public String getMethod() { + return method; + } + + public String getQueryString() { + return queryString; + } + + public Map getHeaders() { + return headers; + } + + public String getRequestFormat() { + return requestFormat; + } + + public Map getRequestContent() { + return requestContent; + } + + public String getRequestContentString() { + return requestContentString; + } + + public ResponseJsonParser getResponseJsonParser() { + return responseJsonParser; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + 
builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (description != null) { + builder.field(DESCRIPTION, description); + } + if (version != null) { + builder.field(VERSION, version); + } + if (url != null) { + builder.field(URL, url); + } + builder.startObject(PATH); + { + builder.startObject(path); + { + builder.startObject(method); + { + if (queryString != null) { + builder.field(QUERY_STRING, queryString); + } + if (headers != null) { + builder.field(HEADERS, headers); + } + + builder.startObject(REQUEST); + { + if (requestFormat != null) { + builder.field(REQUEST_FORMAT, requestFormat); + } + if (requestContent != null) { + builder.field(REQUEST_CONTENT, requestContent); + } else if (requestContentString != null) { + builder.field(REQUEST_CONTENT, requestContentString); + } + } + builder.endObject(); + + if (responseJsonParser != null) { + builder.startObject(RESPONSE); + { + responseJsonParser.toXContent(builder, params); + } + builder.endObject(); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endObject(); + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ADD_INFERENCE_CUSTOM_MODEL; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(similarity); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalString(description); + out.writeOptionalString(version); + out.writeString(serviceType); + out.writeString(url); + out.writeString(path); + out.writeString(method); + out.writeOptionalString(queryString); + if (headers != null) { + out.writeBoolean(true); + out.writeGenericMap(headers); + } else { + out.writeBoolean(false); + } + out.writeString(requestFormat); + if (requestContent != null) { + out.writeBoolean(true); + 
out.writeGenericMap(requestContent); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(requestContentString); + if (responseJsonParser != null) { + out.writeBoolean(true); + responseJsonParser.writeTo(out); + } else { + out.writeBoolean(false); + } + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomServiceSettings that = (CustomServiceSettings) o; + return Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(description, that.description) + && Objects.equals(version, that.version) + && Objects.equals(serviceType, that.serviceType) + && Objects.equals(url, that.url) + && Objects.equals(path, that.path) + && Objects.equals(method, that.method) + && Objects.equals(queryString, that.queryString) + && Objects.equals(headers, that.headers) + && Objects.equals(requestFormat, that.requestFormat) + && Objects.equals(requestContent, that.requestContent) + && Objects.equals(requestContentString, that.requestContentString) + && Objects.equals(responseJsonParser, that.responseJsonParser) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash( + similarity, + dimensions, + maxInputTokens, + description, + version, + serviceType, + url, + path, + method, + queryString, + headers, + requestFormat, + requestContent, + requestContentString, + responseJsonParser, + rateLimitSettings + ); + } + + @Override + public String modelId() { + return SERVICE_NAME; + } + + private static ResponseJsonParser extractResponseParser( + TaskType taskType, + Map responseParserMap, + ValidationException validationException + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> extractTextEmbeddingResponseParser(responseParserMap, validationException); + 
case SPARSE_EMBEDDING -> extractSparseTextEmbeddingResponseParser(responseParserMap, validationException); + case RERANK -> extractRerankResponseParser(responseParserMap, validationException); + case COMPLETION -> extractCompletionResponseParser(responseParserMap, validationException); + default -> throw new IllegalArgumentException( + String.format(Locale.ROOT, "response json parser does not support TaskType [%s]", taskType) + ); + }; + } + + private static ResponseJsonParser extractTextEmbeddingResponseParser( + Map responseParserMap, + ValidationException validationException + ) { + return new ResponseJsonParser(TaskType.TEXT_EMBEDDING, responseParserMap, validationException); + } + + private static ResponseJsonParser extractSparseTextEmbeddingResponseParser( + Map responseParserMap, + ValidationException validationException + ) { + return new ResponseJsonParser(TaskType.SPARSE_EMBEDDING, responseParserMap, validationException); + } + + private static ResponseJsonParser extractRerankResponseParser( + Map responseParserMap, + ValidationException validationException + ) { + return new ResponseJsonParser(TaskType.RERANK, responseParserMap, validationException); + } + + private static ResponseJsonParser extractCompletionResponseParser( + Map responseParserMap, + ValidationException validationException + ) { + return new ResponseJsonParser(TaskType.COMPLETION, responseParserMap, validationException); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java new file mode 100644 index 0000000000000..9fed14a2d9a9f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettings.java @@ -0,0 +1,164 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; + +public class CustomTaskSettings implements TaskSettings { + public static final String NAME = "custom_task_settings"; + + public static final String PARAMETERS = "parameters"; + public static final String IGNORE_PLACEHOLDER_CHECK = "ignore_placeholder_check"; + + static final CustomTaskSettings EMPTY_SETTINGS = new CustomTaskSettings(null, null); + + public static CustomTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Map parameters = extractOptionalMap(map, PARAMETERS, ModelConfigurations.TASK_SETTINGS, validationException); + Boolean ignorePlaceholderCheck = extractOptionalBoolean(map, IGNORE_PLACEHOLDER_CHECK, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new CustomTaskSettings(parameters, ignorePlaceholderCheck); + } + + /** + * Creates a new {@link CustomTaskSettings} + * by preferring non-null fields from 
the request settings over the original settings. + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link CustomTaskSettings} + */ + + public static CustomTaskSettings of(CustomTaskSettings originalSettings, CustomTaskSettings requestTaskSettings) { + // If both requestTaskSettings.getParameters() and originalSettings.getParameters() are defined + // the maps should be merged. + if (originalSettings != null + && originalSettings.parameters != null + && requestTaskSettings != null + && requestTaskSettings.parameters != null) { + var copy = new HashMap<>(originalSettings.parameters); + requestTaskSettings.parameters.forEach((key, value) -> copy.merge(key, value, (originalValue, requestValue) -> requestValue)); + Boolean ignorePlaceholderCheck = requestTaskSettings.getIgnorePlaceholderCheck() != null + ? requestTaskSettings.getIgnorePlaceholderCheck() + : originalSettings.getIgnorePlaceholderCheck(); + return new CustomTaskSettings(copy, ignorePlaceholderCheck); + } else { + return new CustomTaskSettings( + requestTaskSettings.getParameters() != null ? requestTaskSettings.getParameters() : originalSettings.getParameters(), + requestTaskSettings.getIgnorePlaceholderCheck() != null + ? requestTaskSettings.getIgnorePlaceholderCheck() + : originalSettings.getIgnorePlaceholderCheck() + ); + } + } + + private final Map parameters; + private Boolean ignorePlaceholderCheck; + + public CustomTaskSettings(StreamInput in) throws IOException { + parameters = in.readBoolean() ? 
in.readGenericMap() : null; + ignorePlaceholderCheck = in.readOptionalBoolean(); + } + + public CustomTaskSettings(Map parameters, Boolean ignorePlaceholderCheck) { + this.parameters = parameters; + this.ignorePlaceholderCheck = ignorePlaceholderCheck; + } + + public Map getParameters() { + return parameters; + } + + public Boolean getIgnorePlaceholderCheck() { + return ignorePlaceholderCheck; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (parameters != null) { + builder.field(PARAMETERS, parameters); + } + if (ignorePlaceholderCheck != null) { + builder.field(IGNORE_PLACEHOLDER_CHECK, ignorePlaceholderCheck); + } + builder.endObject(); + return builder; + } + + public Map getParametersOrDefault() { + return parameters; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_15_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (parameters == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeGenericMap(parameters); + } + out.writeOptionalBoolean(ignorePlaceholderCheck); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CustomTaskSettings that = (CustomTaskSettings) o; + return Objects.equals(parameters, that.parameters) && Objects.equals(ignorePlaceholderCheck, that.ignorePlaceholderCheck); + } + + @Override + public int hashCode() { + return Objects.hash(parameters, ignorePlaceholderCheck); + } + + @Override + public boolean isEmpty() { + return (parameters == null || parameters.isEmpty()) && ignorePlaceholderCheck == null; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + CustomTaskSettings updatedSettings = CustomTaskSettings.fromMap(new 
HashMap<>(newSettings)); + return of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/ResponseJsonParser.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/ResponseJsonParser.java new file mode 100644 index 0000000000000..6c2c1bfda882d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/ResponseJsonParser.java @@ -0,0 +1,274 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.COMPLETION_PARSER_RESULT; +import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.JSON_PARSER; +import static 
/**
 * Holds the path strings that describe where to find inference results inside a custom service's
 * response, keyed by task type. Only the path fields that belong to the configured task type are
 * assigned; all other path fields stay {@code null}.
 * <p>
 * The stream constructor, {@link #writeTo(StreamOutput)} and {@link #toXContent} all switch on the
 * task type and must stay in step with each other.
 */
public class ResponseJsonParser {
    public static final String FIELD_NAME = JSON_PARSER;

    // String form of the task type; converted back with TaskType.fromString when (de)serializing.
    private String taskTypeStr;

    // TEXT_EMBEDDING
    private String textEmbeddingsPath;

    // SPARSE_EMBEDDING
    private String sparseResultPath;
    private String sparseTokenPath;
    private String sparseWeightPath;

    // RERANK (index and document text are optional, score is required)
    private String rerankedIndexPath;
    private String relevanceScorePath;
    private String documentTextPath;

    // COMPLETION
    private String completionResultPath;

    /**
     * Builds the parser from the {@code json_parser} configuration map.
     *
     * @param taskType the task type that decides which paths are read
     * @param responseParserMap the parser configuration map
     * @param validationException passed to every extraction helper; thrown directly from here when a
     *                            required nested sparse-embedding map is missing
     * @throws IllegalArgumentException for any task type other than the four handled below
     */
    public ResponseJsonParser(TaskType taskType, Map responseParserMap, ValidationException validationException) {
        this.taskTypeStr = taskType.toString();
        switch (taskType) {
            case TEXT_EMBEDDING -> textEmbeddingsPath = extractRequiredString(
                responseParserMap,
                TEXT_EMBEDDING_PARSER_EMBEDDINGS,
                JSON_PARSER,
                validationException
            );
            case SPARSE_EMBEDDING -> {
                Map sparseResultMap = extractRequiredMap(
                    responseParserMap,
                    SPARSE_EMBEDDING_RESULT,
                    JSON_PARSER,
                    validationException
                );
                // Nothing further can be extracted without the nested map, so fail immediately.
                if (sparseResultMap == null) {
                    throw validationException;
                }
                sparseResultPath = extractRequiredString(sparseResultMap, SPARSE_RESULT_PATH, JSON_PARSER, validationException);
                Map sparseResultValueMap = extractRequiredMap(
                    sparseResultMap,
                    SPARSE_RESULT_VALUE,
                    JSON_PARSER,
                    validationException
                );
                // Same reasoning as above for the second nesting level.
                if (sparseResultValueMap == null) {
                    throw validationException;
                }
                sparseTokenPath = extractRequiredString(
                    sparseResultValueMap,
                    SPARSE_EMBEDDING_PARSER_TOKEN,
                    JSON_PARSER,
                    validationException
                );
                sparseWeightPath = extractRequiredString(
                    sparseResultValueMap,
                    SPARSE_EMBEDDING_PARSER_WEIGHT,
                    JSON_PARSER,
                    validationException
                );
            }
            case RERANK -> {
                rerankedIndexPath = extractOptionalString(responseParserMap, RERANK_PARSER_INDEX, JSON_PARSER, validationException);

                relevanceScorePath = extractRequiredString(responseParserMap, RERANK_PARSER_SCORE, JSON_PARSER, validationException);

                documentTextPath = extractOptionalString(responseParserMap, RERANK_PARSER_DOCUMENT_TEXT, JSON_PARSER, validationException);
            }
            case COMPLETION -> completionResultPath = extractRequiredString(
                responseParserMap,
                COMPLETION_PARSER_RESULT,
                JSON_PARSER,
                validationException
            );
            default -> throw new IllegalArgumentException(
                String.format(Locale.ROOT, "json parser does not support taskType [%s]", taskType)
            );
        }
    }

    /**
     * Reads the parser from the wire: the task type string first, then the paths for that task type
     * in the same order {@link #writeTo(StreamOutput)} writes them.
     * <p>
     * NOTE(review): unlike the map-based constructor there is no {@code default} branch, so any other
     * task type reads no paths and leaves every path field {@code null}.
     */
    public ResponseJsonParser(StreamInput in) throws IOException {
        this.taskTypeStr = in.readString();
        TaskType taskType = TaskType.fromString(this.taskTypeStr);
        switch (taskType) {
            case TEXT_EMBEDDING -> this.textEmbeddingsPath = in.readString();
            case SPARSE_EMBEDDING -> {
                this.sparseResultPath = in.readString();
                this.sparseTokenPath = in.readString();
                this.sparseWeightPath = in.readString();
            }
            case RERANK -> {
                // index and document text are optional, hence the optional reads around the required score
                this.rerankedIndexPath = in.readOptionalString();
                this.relevanceScorePath = in.readString();
                this.documentTextPath = in.readOptionalString();
            }
            case COMPLETION -> this.completionResultPath = in.readString();
        }
    }

    /**
     * Writes the task type string followed by the paths of that task type. The field order here is
     * the wire format and must match the stream constructor.
     */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.taskTypeStr);
        TaskType taskType = TaskType.fromString(this.taskTypeStr);
        switch (taskType) {
            case TEXT_EMBEDDING -> out.writeString(this.textEmbeddingsPath);
            case SPARSE_EMBEDDING -> {
                out.writeString(this.sparseResultPath);
                out.writeString(this.sparseTokenPath);
                out.writeString(this.sparseWeightPath);
            }
            case RERANK -> {
                out.writeOptionalString(this.rerankedIndexPath);
                out.writeString(this.relevanceScorePath);
                out.writeOptionalString(this.documentTextPath);
            }
            case COMPLETION -> out.writeString(this.completionResultPath);
        }
    }

    /** @return the text-embedding path, or {@code null} for other task types */
    public String getTextEmbeddingsPath() {
        return textEmbeddingsPath;
    }

    /** @return the sparse-embedding result path, or {@code null} for other task types */
    public String getSparseResultPath() {
        return sparseResultPath;
    }

    /** @return the sparse-embedding token path, or {@code null} for other task types */
    public String getSparseTokenPath() {
        return sparseTokenPath;
    }

    /** @return the sparse-embedding weight path, or {@code null} for other task types */
    public String getSparseWeightPath() {
        return sparseWeightPath;
    }

    /** @return the rerank index path; {@code null} when not configured or for other task types */
    public String getRerankedIndexPath() {
        return rerankedIndexPath;
    }

    /** @return the rerank relevance-score path, or {@code null} for other task types */
    public String getRelevanceScorePath() {
        return relevanceScorePath;
    }

    /** @return the rerank document-text path; {@code null} when not configured or for other task types */
    public String getDocumentTextPath() {
        return documentTextPath;
    }

    /** @return the completion result path, or {@code null} for other task types */
    public String getCompletionResultPath() {
        return completionResultPath;
    }

    /**
     * Writes the parser as an object named {@link #FIELD_NAME} containing the paths of the configured
     * task type. Optional rerank paths are omitted when {@code null}. The task type itself is not
     * written by this method.
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(FIELD_NAME);
        {
            switch (TaskType.fromString(this.taskTypeStr)) {
                case TEXT_EMBEDDING -> builder.field(TEXT_EMBEDDING_PARSER_EMBEDDINGS, textEmbeddingsPath);
                case SPARSE_EMBEDDING -> {
                    // Mirrors the nesting read by the map-based constructor: result -> { path, value -> { token, weight } }
                    builder.startObject(SPARSE_EMBEDDING_RESULT);
                    {
                        builder.field(SPARSE_RESULT_PATH, sparseResultPath);
                        builder.startObject(SPARSE_RESULT_VALUE);
                        {
                            builder.field(SPARSE_EMBEDDING_PARSER_TOKEN, sparseTokenPath);
                            builder.field(SPARSE_EMBEDDING_PARSER_WEIGHT, sparseWeightPath);
                        }
                        builder.endObject();
                    }
                    builder.endObject();
                }
                case RERANK -> {
                    if (rerankedIndexPath != null) {
                        builder.field(RERANK_PARSER_INDEX, rerankedIndexPath);
                    }
                    builder.field(RERANK_PARSER_SCORE, relevanceScorePath);
                    if (documentTextPath != null) {
                        builder.field(RERANK_PARSER_DOCUMENT_TEXT, documentTextPath);
                    }
                }
                case COMPLETION -> builder.field(COMPLETION_PARSER_RESULT, completionResultPath);
            }
        }
        builder.endObject();
        return builder;
    }

    /**
     * Extracts the {@link #FIELD_NAME} object from {@code map} and builds a parser from it. The task
     * type is read from the {@code TaskType.NAME} entry inside that object.
     *
     * @param map the enclosing settings map
     * @param validationException thrown directly when the parser object or its task type is missing
     * @param serviceName passed to {@code throwIfNotEmptyMap} in request context
     * @param context when it is a request context, the parser map is checked with {@code throwIfNotEmptyMap}
     *                after construction
     * @return the constructed parser
     */
    public static ResponseJsonParser of(
        Map map,
        ValidationException validationException,
        String serviceName,
        ConfigurationParseContext context
    ) {
        Map responseParserMap = extractRequiredMap(map, FIELD_NAME, JSON_PARSER, validationException);
        if (responseParserMap == null) {
            throw validationException;
        }
        String taskTypeStr = extractRequiredString(responseParserMap, TaskType.NAME, FIELD_NAME, validationException);
        if (taskTypeStr == null) {
            throw validationException;
        }
        TaskType taskType = TaskType.fromString(taskTypeStr);
        ResponseJsonParser responseJsonParser = new ResponseJsonParser(taskType, responseParserMap, validationException);

        // NOTE(review): presumably the extract* helpers remove the keys they consume, so anything left is an unknown setting — confirm
        if (ConfigurationParseContext.isRequestContext(context)) {
            throwIfNotEmptyMap(responseParserMap, serviceName);
        }
        return responseJsonParser;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ResponseJsonParser that = (ResponseJsonParser) o;
        return Objects.equals(taskTypeStr, that.taskTypeStr)
            && Objects.equals(textEmbeddingsPath, that.textEmbeddingsPath)
            && Objects.equals(sparseResultPath, that.sparseResultPath)
            && Objects.equals(sparseTokenPath, that.sparseTokenPath)
            && Objects.equals(sparseWeightPath, that.sparseWeightPath)
            && Objects.equals(rerankedIndexPath, that.rerankedIndexPath)
            && Objects.equals(relevanceScorePath, that.relevanceScorePath)
            && Objects.equals(documentTextPath, that.documentTextPath)
            && Objects.equals(completionResultPath, that.completionResultPath);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            taskTypeStr,
            textEmbeddingsPath,
            sparseResultPath,
            sparseTokenPath,
            sparseWeightPath,
            rerankedIndexPath,
            relevanceScorePath,
            documentTextPath,
            completionResultPath
        );
    }
}
var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var uri = httpPost.getURI().toString(); + MatcherAssert.assertThat(uri, is(CustomModelTests.url + CustomModelTests.path + queryStringRes)); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(CustomModelTests.secretSettingsValue)); + + String requestBody = convertStreamToString(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestBody, is("\"input\":\"[\"abc\"]\"")); + } + + public static CustomRequest createRequest(String query, List input, CustomModel model) { + return new CustomRequest(query, input, model); + } + + private static String convertStreamToString(InputStream inputStream) { + StringBuilder stringBuilder = new StringBuilder(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + String line; + while ((line = reader.readLine()) != null) { + stringBuilder.append(line); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return stringBuilder.toString(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntityTests.java new file mode 100644 index 0000000000000..f73583d2a556d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/custom/CustomResponseEntityTests.java @@ -0,0 +1,229 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
/**
 * Tests {@link CustomResponseEntity#fromResponse} for each supported task type: a canned JSON
 * response is parsed with the paths configured on the test model and the resulting results object
 * is checked.
 */
public class CustomResponseEntityTests extends ESTestCase {

    /** Uses the default test model (text embedding) and expects a single float embedding. */
    public void testFromTextEmbeddingResponse() throws IOException {
        String responseJson = """
            {
                "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
                "latency": 38,
                "usage": {
                    "token_count": 3072
                },
                "result": {
                    "embeddings": [
                        {
                            "index": 0,
                            "embedding": [
                                -0.02868066355586052,
                                0.022033605724573135
                            ]
                        }
                    ]
                }
            }
            """;

        var request = CustomRequestTests.createRequest(null, List.of("abc"), CustomModelTests.getTestModel());
        InferenceServiceResults results = CustomResponseEntity.fromResponse(
            request,
            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
        );

        assertThat(results, instanceOf(TextEmbeddingFloatResults.class));
        assertThat(
            ((TextEmbeddingFloatResults) results).embeddings(),
            is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.02868066355586052f, 0.022033605724573135f })))
        );
    }

    /** Configures nested sparse-embedding paths and expects the token ids and weights from the response. */
    public void testFromSparseEmbeddingResponse() throws IOException {
        String responseJson = """
            {
                "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
                "latency": 22,
                "usage": {
                    "token_count": 11
                },
                "result": {
                    "sparse_embeddings": [
                        {
                            "index": 0,
                            "embedding": [
                                {
                                    "tokenId": 6,
                                    "weight": 0.10137939453125
                                },
                                {
                                    "tokenId": 163040,
                                    "weight": 0.2841796875
                                }
                            ]
                        }
                    ]
                }
            }
            """;

        // Parser configuration: result -> { path, value -> { token, weight } }, wrapped in mutable maps.
        Map jsonParserMap = new HashMap<>(
            Map.of(
                CustomServiceSettings.SPARSE_EMBEDDING_RESULT,
                new HashMap<>(
                    Map.of(
                        CustomServiceSettings.SPARSE_RESULT_PATH,
                        "$.result.sparse_embeddings[*]",
                        CustomServiceSettings.SPARSE_RESULT_VALUE,
                        new HashMap<>(
                            Map.of(
                                CustomServiceSettings.SPARSE_EMBEDDING_PARSER_TOKEN,
                                "$.embedding[*].tokenId",
                                CustomServiceSettings.SPARSE_EMBEDDING_PARSER_WEIGHT,
                                "$.embedding[*].weight"
                            )
                        )
                    )
                )
            )
        );

        var request = CustomRequestTests.createRequest(
            null,
            List.of("abc"),
            CustomModelTests.getTestModel(TaskType.SPARSE_EMBEDDING, jsonParserMap)
        );

        InferenceServiceResults results = CustomResponseEntity.fromResponse(
            request,
            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
        );
        assertThat(results, instanceOf(SparseEmbeddingResults.class));

        SparseEmbeddingResults sparseEmbeddingResults = (SparseEmbeddingResults) results;

        // Expected: one embedding holding both tokens; numeric token ids are expected as strings.
        List embeddingList = new ArrayList<>();
        List weightedTokens = new ArrayList<>();
        weightedTokens.add(new WeightedToken("6", 0.10137939453125f));
        weightedTokens.add(new WeightedToken("163040", 0.2841796875f));
        embeddingList.add(new SparseEmbeddingResults.Embedding(weightedTokens, false));

        // NOTE(review): iterates the expected list only; extra embeddings in the actual result would go unnoticed.
        for (int i = 0; i < embeddingList.size(); i++) {
            assertThat(sparseEmbeddingResults.embeddings().get(i), is(embeddingList.get(i)));
        }
    }

    /** Configures index and score paths (no document text path) and checks the order of returned indices. */
    public void testFromRerankResponse() throws IOException {
        String responseJson = """
            {
                "request_id": "450fcb80-f796-46c1-8d69-e1e86d29aa9f",
                "latency": 564.903929,
                "usage": {
                    "doc_count": 2
                },
                "result": {
                    "scores":[
                        {
                            "index":1,
                            "score": 1.37
                        },
                        {
                            "index":0,
                            "score": -0.3
                        }
                    ]
                }
            }
            """;

        Map jsonParserMap = new HashMap<>(
            Map.of(
                CustomServiceSettings.RERANK_PARSER_INDEX,
                "$.result.scores[*].index",
                CustomServiceSettings.RERANK_PARSER_SCORE,
                "$.result.scores[*].score"
            )
        );

        var request = CustomRequestTests.createRequest(null, List.of("abc"), CustomModelTests.getTestModel(TaskType.RERANK, jsonParserMap));

        InferenceServiceResults results = CustomResponseEntity.fromResponse(
            request,
            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
        );

        assertThat(results, instanceOf(RankedDocsResults.class));
        // NOTE(review): the element type argument looks lost in extraction; expected.get(i).index() below
        // needs a list of RankedDocsResults.RankedDoc — confirm against the real file.
        var expected = new ArrayList();
        expected.add(new RankedDocsResults.RankedDoc(1, 1.37F, null));
        expected.add(new RankedDocsResults.RankedDoc(0, -0.3F, null));

        // NOTE(review): the loop is bounded by the actual result size and compares only index(); an empty
        // result would pass and the scores are never checked.
        for (int i = 0; i < ((RankedDocsResults) results).getRankedDocs().size(); i++) {
            assertThat(((RankedDocsResults) results).getRankedDocs().get(i).index(), is(expected.get(i).index()));
        }
    }

    /** Configures the completion result path and expects exactly one result carrying the response text. */
    public void testFromCompletionResponse() throws IOException {
        String responseJson = """
            {
                "request_id": "450fcb80-f796-****-8d69-e1e86d29aa9f",
                "latency": 564.903929,
                "result": {
                    "text":"completion results"
                },
                "usage": {
                    "output_tokens": 6320,
                    "input_tokens": 35,
                    "total_tokens": 6355
                }
            }
            """;

        Map jsonParserMap = new HashMap<>(Map.of(CustomServiceSettings.COMPLETION_PARSER_RESULT, "$.result.text"));

        var request = CustomRequestTests.createRequest(
            null,
            List.of("abc"),
            CustomModelTests.getTestModel(TaskType.COMPLETION, jsonParserMap)
        );

        InferenceServiceResults results = CustomResponseEntity.fromResponse(
            request,
            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
        );

        assertThat(results, instanceOf(ChatCompletionResults.class));
        ChatCompletionResults chatCompletionResults = (ChatCompletionResults) results;
        assertThat(chatCompletionResults.getResults().size(), is(1));
        assertThat(chatCompletionResults.getResults().get(0).content(), is("completion results"));
    }
}
/**
 * Tests for {@link CustomModel}, plus static factories and fixture values that other custom-service
 * tests reuse to build a fully populated model.
 */
public class CustomModelTests extends ESTestCase {
    // Key/value placed in the task settings; the key is referenced as a ${...} placeholder in the query string.
    public static String taskSettingsKey = "test_taskSettings_key";
    public static String taskSettingsValue = "test_taskSettings_value";

    // Key/value placed in the secret settings; the key is referenced as a ${...} placeholder in the Authorization header.
    public static String secretSettingsKey = "test_secret_key";
    public static String secretSettingsValue = "test_secret_value";
    public static String url = "http://www.abc.com";
    public static String path = "/endpoint";

    /** Overriding a model with an empty task-settings map must yield a model equal to the original. */
    public void testOverride() {
        var model = createModel(
            "service",
            TaskType.TEXT_EMBEDDING,
            CustomServiceSettingsTests.createRandom(),
            CustomTaskSettingsTests.createRandom(),
            CustomSecretSettingsTests.createRandom()
        );

        var overriddenModel = CustomModel.of(model, Map.of());
        MatcherAssert.assertThat(overriddenModel, is(model));
    }

    /**
     * Creates a model from raw settings maps. The final constructor argument is passed as {@code null}.
     *
     * @param modelId the model id
     * @param taskType the task type
     * @param serviceSettings raw service settings
     * @param taskSettings raw task settings
     * @param secrets raw secret settings, may be {@code null}
     */
    public static CustomModel createModel(
        String modelId,
        TaskType taskType,
        Map serviceSettings,
        Map taskSettings,
        @Nullable Map secrets
    ) {
        return new CustomModel(modelId, taskType, CustomUtils.SERVICE_NAME, serviceSettings, taskSettings, secrets, null);
    }

    /**
     * Creates a model from already-constructed settings objects.
     *
     * @param secretSettings may be {@code null}
     */
    public static CustomModel createModel(
        String modelId,
        TaskType taskType,
        CustomServiceSettings serviceSettings,
        CustomTaskSettings taskSettings,
        @Nullable CustomSecretSettings secretSettings
    ) {
        return new CustomModel(modelId, taskType, CustomUtils.SERVICE_NAME, serviceSettings, taskSettings, secretSettings);
    }

    /** Returns the shared text-embedding test model with a single embeddings path configured. */
    public static CustomModel getTestModel() {
        TaskType taskType = TaskType.TEXT_EMBEDDING;
        Map jsonParserMap = new HashMap<>(
            Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
        );
        return getTestModel(taskType, jsonParserMap);
    }

    /**
     * Builds a model for {@code taskType} with fixed service settings: a POST to {@link #url} +
     * {@link #path}, a query string and Authorization header containing {@code ${...}} placeholders
     * for the task-setting and secret keys, and a string request body containing {@code ${input}}.
     *
     * @param taskType the task type of the model and of its response parser
     * @param jsonParserMap the parser configuration handed to {@link ResponseJsonParser}
     */
    public static CustomModel getTestModel(TaskType taskType, Map jsonParserMap) {
        // service settings
        Integer dims = 1536;
        Integer maxInputTokens = 512;
        String description = "test fromMap";
        String version = "v1";
        String serviceType = taskType.toString();
        String method = HttpMethod.POST.name();
        String queryString = "?query=${" + taskSettingsKey + "}";
        Map headers = Map.of(HttpHeaders.AUTHORIZATION, "${" + secretSettingsKey + "}");
        String requestFormat = CustomServiceSettings.REQUEST_FORMAT_STRING;
        String requestContentString = "\"input\":\"${input}\"";

        // A fresh ValidationException is supplied and never inspected afterwards.
        ResponseJsonParser responseJsonParser = new ResponseJsonParser(taskType, jsonParserMap, new ValidationException());

        // The null argument sits between requestFormat and requestContentString; the random factory in
        // CustomServiceSettingsTests passes its requestContent map in that position.
        CustomServiceSettings serviceSettings = new CustomServiceSettings(
            SimilarityMeasure.DOT_PRODUCT,
            dims,
            maxInputTokens,
            description,
            version,
            serviceType,
            url,
            path,
            method,
            queryString,
            headers,
            requestFormat,
            null,
            requestContentString,
            responseJsonParser,
            new RateLimitSettings(10_000)
        );

        // task settings
        CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue), false);

        // secret settings
        CustomSecretSettings secretSettings = new CustomSecretSettings(Map.of(secretSettingsKey, secretSettingsValue));

        return CustomModelTests.createModel("service", taskType, serviceSettings, taskSettings, secretSettings);
    }
}
/**
 * Wire-serialization and parsing tests for {@link CustomSecretSettings}.
 */
public class CustomSecretSettingsTests extends AbstractWireSerializingTestCase {

    /**
     * Creates settings whose secret parameters are either {@code null} or a random map of up to five
     * five-character string entries.
     */
    public static CustomSecretSettings createRandom() {
        var secretParameters = randomBoolean()
            ? randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5)))
            : null;
        return new CustomSecretSettings(secretParameters);
    }

    /** Parsing a map with a nested secret-parameters object yields settings holding that object. */
    public void testFromMap() {
        // Mutable maps are used for the input; Map.of alone would be immutable.
        Map secretParameters = new HashMap<>(
            Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))
        );

        MatcherAssert.assertThat(
            CustomSecretSettings.fromMap(secretParameters),
            is(new CustomSecretSettings(Map.of("test_key", "test_value")))
        );
    }

    /** Checks the exact JSON produced by {@code toXContent} for a single secret parameter. */
    public void testXContent() throws IOException {
        var entity = new CustomSecretSettings(Map.of("test_key", "test_value"));

        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
        entity.toXContent(builder, null);
        String xContentResult = Strings.toString(builder);

        assertThat(xContentResult, is("{\"secret_parameters\":{\"test_key\":\"test_value\"}}"));
    }

    @Override
    protected Writeable.Reader instanceReader() {
        return CustomSecretSettings::new;
    }

    @Override
    protected CustomSecretSettings createTestInstance() {
        return createRandom();
    }

    /**
     * No mutation is implemented; always returns {@code null}.
     * NOTE(review): presumably the base class treats {@code null} as "skip mutation-based checks" — confirm.
     */
    @Override
    protected CustomSecretSettings mutateInstance(CustomSecretSettings instance) {
        return null;
    }
}
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import io.netty.handler.codec.http.HttpMethod; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.core.Tuple.tuple; +import static org.hamcrest.Matchers.is; + +public class CustomServiceSettingsTests extends AbstractWireSerializingTestCase { + public static CustomServiceSettings createRandom(String inputUrl, String inputPath, String inputQueryString) { + List taskTypeStrList = Arrays.stream(TaskType.values()).map(TaskType::toString).toList(); + TaskType taskType = TaskType.fromString(randomFrom(taskTypeStrList)); + + SimilarityMeasure similarityMeasure = null; + Integer dims = null; + var isTextEmbeddingModel = taskType.equals(TaskType.TEXT_EMBEDDING); + if (isTextEmbeddingModel) { + similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + dims = 1536; + } + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + String description = randomAlphaOfLength(15); + String version = randomAlphaOfLength(5); + String url = inputUrl != null ? inputUrl : randomAlphaOfLength(15); + String path = inputPath != null ? 
inputPath : randomAlphaOfLength(15); + String method = randomFrom(HttpMethod.PUT.name(), HttpMethod.POST.name(), HttpMethod.GET.name()); + String queryString = inputQueryString != null ? inputQueryString : randomAlphaOfLength(15); + Map headers = randomBoolean() ? null : Map.of("key", "value"); + String requestFormat = randomFrom(CustomServiceSettings.REQUEST_FORMAT_JSON, CustomServiceSettings.REQUEST_FORMAT_STRING); + Map requestContent = randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), randomAlphaOfLength(5))); + String requestContentString = randomBoolean() ? null : randomAlphaOfLength(10); + + Map textEmbeddingJsonParserMap = new HashMap<>( + Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ); + Map sparseEmbeddingJsonParserMap = new HashMap<>( + Map.of( + CustomServiceSettings.SPARSE_EMBEDDING_RESULT, + new HashMap<>( + Map.of( + CustomServiceSettings.SPARSE_RESULT_PATH, + "$.result.sparse_embeddings[*]", + CustomServiceSettings.SPARSE_RESULT_VALUE, + new HashMap<>( + Map.of( + CustomServiceSettings.SPARSE_EMBEDDING_PARSER_TOKEN, + "$.embedding[*].token_id", + CustomServiceSettings.SPARSE_EMBEDDING_PARSER_WEIGHT, + "$.embedding[*].weights" + ) + ) + ) + ) + ) + ); + Map rerankJsonParserMap = new HashMap<>( + Map.of( + CustomServiceSettings.RERANK_PARSER_INDEX, + "$.result.reranked_results[*].index", + CustomServiceSettings.RERANK_PARSER_SCORE, + "$.result.reranked_results[*].relevance_score", + CustomServiceSettings.RERANK_PARSER_DOCUMENT_TEXT, + "$.result.reranked_results[*].document_text" + ) + ); + Map completionJsonParserMap = new HashMap<>( + Map.of(CustomServiceSettings.COMPLETION_PARSER_RESULT, "$.result.text") + ); + + ResponseJsonParser responseJsonParser = switch (taskType) { + case TEXT_EMBEDDING -> new ResponseJsonParser(taskType, textEmbeddingJsonParserMap, new ValidationException()); + case SPARSE_EMBEDDING -> new ResponseJsonParser(taskType, sparseEmbeddingJsonParserMap, new 
ValidationException()); + case RERANK -> new ResponseJsonParser(taskType, rerankJsonParserMap, new ValidationException()); + case COMPLETION -> new ResponseJsonParser(taskType, completionJsonParserMap, new ValidationException()); + default -> null; + }; + + RateLimitSettings rateLimitSettings = new RateLimitSettings(randomLongBetween(1, 1000000)); + + return new CustomServiceSettings( + similarityMeasure, + dims, + maxInputTokens, + description, + version, + taskType.name(), + url, + path, + method, + queryString, + headers, + requestFormat, + requestContent, + requestContentString, + responseJsonParser, + rateLimitSettings + ); + } + + public static CustomServiceSettings createRandom() { + return createRandom(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5)); + } + + public void testFromMap() { + String similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + Integer dims = 1536; + Integer maxInputTokens = 512; + String description = "test fromMap"; + String version = "v1"; + String serviceType = TaskType.TEXT_EMBEDDING.toString(); + String url = "http://www.abc.com"; + String path = "/endpoint"; + String method = HttpMethod.POST.name(); + String queryString = "?query=test"; + Map headers = Map.of("key", "value"); + String requestFormat = CustomServiceSettings.REQUEST_FORMAT_STRING; + String requestContentString = "request body"; + + Map jsonParserMap = new HashMap<>( + Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ); + ResponseJsonParser responseJsonParser = new ResponseJsonParser(TaskType.TEXT_EMBEDDING, jsonParserMap, new ValidationException()); + + var settings = CustomServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + CustomServiceSettings.DESCRIPTION, + description, + CustomServiceSettings.VERSION, + version, + CustomServiceSettings.URL, + url, + 
CustomServiceSettings.PATH, + new HashMap<>( + Map.of( + path, + new HashMap<>( + Map.of( + method, + new HashMap<>( + Map.of( + CustomServiceSettings.QUERY_STRING, + queryString, + CustomServiceSettings.HEADERS, + headers, + CustomServiceSettings.REQUEST, + new HashMap<>( + Map.of( + CustomServiceSettings.REQUEST_FORMAT, + requestFormat, + CustomServiceSettings.REQUEST_CONTENT, + requestContentString + ) + ), + CustomServiceSettings.RESPONSE, + new HashMap<>( + Map.of( + CustomServiceSettings.JSON_PARSER, + new HashMap<>( + Map.of( + CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, + "$.result.embeddings[*].embedding" + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ), + null, + TaskType.TEXT_EMBEDDING + ); + + MatcherAssert.assertThat( + settings, + is( + new CustomServiceSettings( + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + description, + version, + serviceType, + url, + path, + method, + queryString, + headers, + requestFormat, + null, + requestContentString, + responseJsonParser, + new RateLimitSettings(10_000) + ) + ) + ); + } + + public void testXContent() throws IOException { + Map jsonParserMap = new HashMap<>( + Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding") + ); + + ResponseJsonParser responseJsonParser = new ResponseJsonParser(TaskType.TEXT_EMBEDDING, jsonParserMap, new ValidationException()); + + var entity = new CustomServiceSettings( + null, + null, + null, + "test fromMap", + "v1", + TaskType.TEXT_EMBEDDING.toString(), + "http://www.abc.com", + "/endpoint", + HttpMethod.POST.name(), + "?query=test", + Map.of("key", "value"), + CustomServiceSettings.REQUEST_FORMAT_STRING, + null, + "request body", + responseJsonParser, + null + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat( + xContentResult, + is( + "{\"description\":\"test 
fromMap\",\"version\":\"v1\"," + + "\"url\":\"http://www.abc.com\",\"path\":{\"/endpoint\":{\"POST\":{\"query_string\":\"?query=test\"," + + "\"headers\":{\"key\":\"value\"},\"request\":{\"format\":\"string\",\"content\":\"request body\"}," + + "\"response\":{\"json_parser\":{\"text_embeddings\":\"$.result.embeddings[*].embedding\"}}}}}," + + "\"rate_limit\":{\"requests_per_minute\":10000}}" + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomServiceSettings::new; + } + + @Override + protected CustomServiceSettings createTestInstance() { + return createRandom(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5)); + } + + @Override + protected CustomServiceSettings mutateInstance(CustomServiceSettings instance) { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java new file mode 100644 index 0000000000000..bd7db6410da48 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.custom; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.core.Tuple.tuple; +import static org.hamcrest.Matchers.is; + +public class CustomTaskSettingsTests extends AbstractWireSerializingTestCase { + public static CustomTaskSettings createRandom() { + var parameters = randomBoolean() ? randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5))) : null; + var ignorePlaceholderCheck = randomBoolean(); + return new CustomTaskSettings(parameters, ignorePlaceholderCheck); + } + + public void testFromMap() { + Map taskSettingsMap = new HashMap<>( + Map.of( + CustomTaskSettings.PARAMETERS, + new HashMap<>(Map.of("test_key", "test_value")), + CustomTaskSettings.IGNORE_PLACEHOLDER_CHECK, + true + ) + ); + + MatcherAssert.assertThat( + CustomTaskSettings.fromMap(taskSettingsMap), + is(new CustomTaskSettings(Map.of("test_key", "test_value"), true)) + ); + } + + public void testXContent() throws IOException { + var entity = new CustomTaskSettings(Map.of("test_key", "test_value"), true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is("{\"parameters\":{\"test_key\":\"test_value\"},\"ignore_placeholder_check\":true}")); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomTaskSettings::new; + } + + @Override + protected CustomTaskSettings createTestInstance() { + 
return createRandom(); + } + + @Override + protected CustomTaskSettings mutateInstance(CustomTaskSettings instance) throws IOException { + return null; + } + + public static Map getTaskSettingsMap( + @Nullable Map parameters, + @Nullable Boolean ignorePlaceholderCheck + ) { + var map = new HashMap(); + if (parameters != null) { + map.put(CustomTaskSettings.PARAMETERS, parameters); + } + if (ignorePlaceholderCheck != null) { + map.put(CustomTaskSettings.IGNORE_PLACEHOLDER_CHECK, ignorePlaceholderCheck); + } + + return map; + } +}